示例#1
0
    def __init__(self, configuration):
        """
        Initialize the internal structures of this agent.

        Every attribute is first set to its default. When `configuration` is
        truthy it is then validated and applied in this order: DIALOGUE
        settings, GENERAL settings, AGENT_0 parameters, AGENT_0 modules.
        The order matters because later steps read state set by earlier ones
        (ontology, database, domain, global arguments, user-simulator
        initiative).

        :param configuration: a dictionary representing the configuration file
        :raises ValueError: if GENERAL, GENERAL.interaction_mode, DIALOGUE,
            AGENT_0 or AGENT_0.role is missing, or if a 'user' agent has no
            ontology and database to generate goals from
        :raises FileNotFoundError: if a configured ontology, database, goals
            or dialogue log file does not exist
        """

        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = -1
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Default meta-parameter values (AGENT_0 may override them)
        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 3

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 15

        self.ontology = None
        self.database = None
        self.domain = None
        self.global_args = {}
        self.dialogue_manager = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        # (SlotFillingGoalAdvancementReward is the known alternative).
        self.reward_func = SlotFillingReward()

        if self.configuration:
            self._validate_configuration()
            self._configure_dialogue()
            self._configure_general()
            self._configure_agent()
            self._configure_modules()

        # True if at least one module is training
        self.IS_TRAINING = any(
            module and module.training
            for module in (self.nlu, self.dialogue_manager, self.nlg)
        )

    @staticmethod
    def _require_file(path, label):
        """
        Return `path` unchanged if it names an existing file.

        :param path: file path taken from the configuration
        :param label: first word of the error message (e.g. 'Database')
        :raises FileNotFoundError: if `path` is not an existing file
        """
        if not os.path.isfile(path):
            raise FileNotFoundError('%s file %s not found' % (label, path))
        return path

    def _validate_configuration(self):
        """
        Check that the mandatory configuration entries are present.

        Uses .get() so that a missing key produces the descriptive
        ValueError rather than a bare KeyError.

        :raises ValueError: on the first missing mandatory entry
        """
        general = self.configuration.get('GENERAL')
        if not general:
            raise ValueError('Cannot run Plato without GENERAL settings!')
        if not general.get('interaction_mode'):
            raise ValueError('Cannot run Plato without an '
                             'interaction mode!')
        if not self.configuration.get('DIALOGUE'):
            raise ValueError('Cannot run Plato without DIALOGUE settings!')
        if not self.configuration.get('AGENT_0'):
            raise ValueError('Cannot run Plato without at least '
                             'one agent!')

    def _configure_dialogue(self):
        """
        Apply the DIALOGUE section: initiative, domain, ontology, database
        and goals file.

        The ontology and database fall back to GENERAL.global_arguments when
        the DIALOGUE section does not name them.

        :raises FileNotFoundError: if a referenced file does not exist
        """
        dialogue = self.configuration['DIALOGUE']
        # Tolerate an empty ("global_arguments:" with no value) entry.
        global_arguments = \
            self.configuration['GENERAL'].get('global_arguments') or {}

        if 'initiative' in dialogue:
            self.USER_HAS_INITIATIVE = dialogue['initiative'] == 'user'
            self.user_simulator_args['us_has_initiative'] = \
                self.USER_HAS_INITIATIVE

        if dialogue.get('domain'):
            self.domain = dialogue['domain']

        if 'ontology_path' in dialogue:
            self.ontology = ontology.Ontology(
                self._require_file(dialogue['ontology_path'], 'domain'))
        elif 'ontology' in global_arguments:
            self.ontology = ontology.Ontology(
                self._require_file(global_arguments['ontology'], 'domain'))

        if 'db_path' in dialogue:
            db_path = self._require_file(dialogue['db_path'], 'Database')
            # SQL is the default when no db_type is given.
            if dialogue.get('db_type', 'sql') == 'sql':
                self.database = database.SQLDataBase(db_path)
            else:
                self.database = database.DataBase(db_path)
        elif 'database' in global_arguments:
            # A database taken from global_arguments is always loaded with
            # the generic DataBase class; report the database path itself
            # if it is missing.
            self.database = database.DataBase(
                self._require_file(global_arguments['database'], 'Database'))

        if 'goals_path' in dialogue:
            self.goals_path = self._require_file(
                dialogue['goals_path'], 'Goals')

    def _configure_general(self):
        """
        Apply the GENERAL section: global arguments, experience logs and
        interaction mode.

        :raises FileNotFoundError: if experience-log loading is requested
            but the log file is missing
        """
        general = self.configuration['GENERAL']

        if 'global_arguments' in general:
            self.global_args = general['global_arguments'] or {}

        if 'experience_logs' in general:
            logs = general['experience_logs'] or {}
            dialogues_path = logs.get('path')

            if logs.get('load'):
                if dialogues_path and os.path.isfile(dialogues_path):
                    self.recorder.load(dialogues_path)
                else:
                    raise FileNotFoundError(
                        'dialogue Log file %s not found (did you '
                        'provide one?)' % dialogues_path)

            if 'save' in logs:
                self.recorder.set_path(dialogues_path)
                self.SAVE_LOG = bool(logs['save'])

        interaction_mode = general['interaction_mode']
        if interaction_mode == 'simulation':
            self.USE_USR_SIMULATOR = True
        elif interaction_mode == 'speech':
            self.USE_SPEECH = True
            self.asr = speech_rec.Recognizer()

    def _configure_agent(self):
        """
        Apply the AGENT_0 role, goal generator and optional parameters.

        :raises ValueError: if no role is configured, or if the role is
            'user' and the ontology or database is missing
        """
        agent = self.configuration['AGENT_0']

        if 'role' not in agent:
            raise ValueError(
                'agent: No role assigned for agent {0} in '
                'config!'.format(self.agent_id)
            )
        self.agent_role = agent['role']

        if self.agent_role == 'user':
            if not (self.ontology and self.database):
                raise ValueError(
                    'Conversational Single Agent (user): Cannot generate '
                    'goal without ontology and database.'
                )
            self.goal_generator = GoalGenerator({
                'ontology': self.ontology,
                'database': self.database
            })

        # (config key, attribute name) pairs of optional overrides
        for key, attribute in (('max_turns', 'MAX_TURNS'),
                               ('train_interval', 'train_interval'),
                               ('train_minibatch', 'minibatch_length'),
                               ('train_epochs', 'train_epochs'),
                               ('save_interval', 'SAVE_INTERVAL')):
            if key in agent:
                setattr(self, attribute, agent[key])

    def _configure_modules(self):
        """Load the AGENT_0 user simulator, NLU, DM and NLG modules."""
        agent = self.configuration['AGENT_0']

        if 'USER_SIMULATOR' in agent:
            simulator = agent['USER_SIMULATOR']
            if 'package' in simulator and 'class' in simulator:
                if 'arguments' in simulator:
                    # Copy so the global-argument merge below does not write
                    # back into the configuration dictionary.
                    self.user_simulator_args = \
                        dict(simulator['arguments'] or {})
                self.user_simulator_args.update(self.global_args)

                self.user_simulator = ConversationalGenericAgent.load_module(
                    simulator['package'],
                    simulator['class'],
                    self.user_simulator_args
                )

                if hasattr(self.user_simulator, 'nlu'):
                    self.USER_SIMULATOR_NLU = self.user_simulator.nlu
                if hasattr(self.user_simulator, 'nlg'):
                    self.USER_SIMULATOR_NLG = self.user_simulator.nlg
            else:
                # USER_SIMULATOR names no package/class: fall back to the
                # agenda based simulator with the current arguments.
                self.user_simulator = AgendaBasedUS(self.user_simulator_args)

        if 'NLU' in agent:
            self.nlu = self._load_configured_module('NLU')

        if 'DM' in agent:
            self.dialogue_manager = self._load_configured_module('DM', {
                'settings': self.configuration,
                'ontology': self.ontology,
                'database': self.database,
                'domain': self.domain,
                'agent_id': self.agent_id,
                'agent_role': self.agent_role,
            })

        if 'NLG' in agent:
            self.nlg = self._load_configured_module('NLG')
            if self.nlg:
                self.USE_NLG = True

    def _load_configured_module(self, section, base_args=None):
        """
        Instantiate the module described by AGENT_0[section].

        Arguments are merged in increasing priority: `base_args`, the
        section's 'arguments', then the global arguments.

        :param section: AGENT_0 sub-section name ('NLU', 'DM' or 'NLG')
        :param base_args: optional default arguments for the module
        :return: the loaded module, or None if the section does not name
            both a 'package' and a 'class'
        """
        module_config = self.configuration['AGENT_0'][section]
        if 'package' not in module_config or 'class' not in module_config:
            return None

        # Build a fresh dict so neither the configuration nor base_args is
        # mutated by the merge.
        module_args = dict(base_args or {})
        module_args.update(module_config.get('arguments') or {})
        module_args.update(self.global_args)

        return ConversationalGenericAgent.load_module(
            module_config['package'],
            module_config['class'],
            module_args
        )
示例#2
0
class ConversationalSingleAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - dialogue Acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd?

    - parser
    """

    def __init__(self, configuration):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        """

        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Default meta-parameter values
        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 3

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act.
        self.MAX_TURNS = 15

        self.dialogue_turn = -1
        self.ontology = None
        self.database = None
        self.domain = None
        self.global_args = {}
        self.dialogue_manager = None
        self.user_model = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        self.curr_state = None
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            # Error checks for options the config must have
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an '
                                 'interaction mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration['AGENT_0']:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # dialogue domain self.settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user'
                    )
                    self.user_simulator_args['us_has_initiative'] = \
                        self.USER_HAS_INITIATIVE

                if self.configuration['DIALOGUE']['domain']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if 'ontology_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']
                    ):
                        self.ontology = ontology.Ontology(
                            self.configuration['DIALOGUE']['ontology_path']
                        )
                    else:
                        raise FileNotFoundError(
                            'domain file %s not found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                # Alternatively, look at global_arguments for ontology path
                elif 'global_arguments' in self.configuration['GENERAL'] \
                        and 'ontology' in \
                        self.configuration['GENERAL']['global_arguments']:
                    if os.path.isfile(
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology']
                    ):
                        self.ontology = ontology.Ontology(
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology']
                        )
                    else:
                        raise FileNotFoundError(
                            'domain file %s not found' %
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology'])

                if 'db_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']
                    ):
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = database.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path']
                                )
                            else:
                                self.database = database.DataBase(
                                    self.configuration['DIALOGUE']['db_path']
                                )
                        else:
                            # Default to SQL
                            self.database = database.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path']
                            )
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['DIALOGUE']['db_path']
                        )

                # Alternatively, look at global arguments for db path
                elif 'global_arguments' in self.configuration['GENERAL'] \
                        and 'database' in \
                        self.configuration['GENERAL']['global_arguments']:
                    if os.path.isfile(
                            self.configuration['GENERAL'][
                                'global_arguments']['database']
                    ):
                        self.database = database.DataBase(
                            self.configuration['GENERAL'][
                                'global_arguments']['database']
                        )
                    else:
                        raise FileNotFoundError(
                            'domain file %s not found' %
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology'])

                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']
                    ):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not found' %
                            self.configuration['DIALOGUE']['goals_path']
                        )

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'global_arguments' in self.configuration['GENERAL']:
                    self.global_args = \
                        self.configuration['GENERAL']['global_arguments']
                    
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    if 'load' in \
                            self.configuration['GENERAL']['experience_logs'] \
                            and bool(self.configuration['GENERAL'][
                                         'experience_logs']['load']):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(
                            self.configuration['GENERAL'][
                                'experience_logs']['save']
                        )

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

                elif self.configuration['GENERAL']['interaction_mode'] == \
                        'speech':
                    self.USE_SPEECH = True
                    self.asr = speech_rec.Recognizer()

            # Agent Settings

            # Retrieve agent role
            if 'role' in self.configuration['AGENT_0']:
                self.agent_role = self.configuration['AGENT_0']['role']
            else:
                raise ValueError(
                    'agent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id)
                )

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator({
                        'ontology': self.ontology,
                        'database': self.database
                    })
                else:
                    raise ValueError(
                        'Conversational Single Agent (user): Cannot generate '
                        'goal without ontology and database.'
                    )
                
            # Retrieve agent parameters
            if 'max_turns' in self.configuration['AGENT_0']:
                self.MAX_TURNS = self.configuration['AGENT_0']['max_turns']
                
            if 'train_interval' in self.configuration['AGENT_0']:
                self.train_interval = \
                    self.configuration['AGENT_0']['train_interval']
                
            if 'train_minibatch' in self.configuration['AGENT_0']:
                self.minibatch_length = \
                    self.configuration['AGENT_0']['train_minibatch']
                
            if 'train_epochs' in self.configuration['AGENT_0']:
                self.train_epochs = \
                    self.configuration['AGENT_0']['train_epochs']

            if 'save_interval' in self.configuration['AGENT_0']:
                self.SAVE_INTERVAL = \
                    self.configuration['AGENT_0']['save_interval']
            
            # usr Simulator
            # Check for specific simulator self.settings, otherwise
            # default to agenda
            if 'USER_SIMULATOR' in self.configuration['AGENT_0']:
                # Agent 0 simulator configuration
                if 'package' in \
                    self.configuration['AGENT_0']['USER_SIMULATOR'] and \
                        'class' in \
                        self.configuration['AGENT_0']['USER_SIMULATOR']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['USER_SIMULATOR']:
                        self.user_simulator_args =\
                            self.configuration[
                                'AGENT_0']['USER_SIMULATOR']['arguments']

                    self.user_simulator_args.update(self.global_args)
                    
                    self.user_simulator = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['USER_SIMULATOR'][
                                'package'],
                            self.configuration['AGENT_0']['USER_SIMULATOR'][
                                'class'],
                            self.user_simulator_args
                        )

                    if hasattr(self.user_simulator, 'nlu'):
                        self.USER_SIMULATOR_NLU = self.user_simulator.nlu
                    if hasattr(self.user_simulator, 'nlg'):
                        self.USER_SIMULATOR_NLG = self.user_simulator.nlg
                else:
                    # Fallback to agenda based simulator with default settings
                    self.user_simulator = AgendaBasedUS(
                        self.user_simulator_args
                    )

            # NLU Settings
            if 'NLU' in self.configuration['AGENT_0']:
                nlu_args = {}
                if 'package' in self.configuration['AGENT_0']['NLU'] and \
                        'class' in self.configuration['AGENT_0']['NLU']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['NLU']:
                        nlu_args = \
                            self.configuration['AGENT_0']['NLU']['arguments']

                    nlu_args.update(self.global_args)

                    self.nlu = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['NLU'][
                                'package'],
                            self.configuration['AGENT_0']['NLU'][
                                'class'],
                            nlu_args
                        )
                    
            # DM Settings
            if 'DM' in self.configuration['AGENT_0']:
                dm_args = dict(
                    zip(
                        ['settings', 'ontology', 'database', 'domain',
                         'agent_id',
                         'agent_role'],
                        [self.configuration,
                         self.ontology,
                         self.database,
                         self.domain,
                         self.agent_id,
                         self.agent_role
                         ]
                    )
                )

                if 'package' in self.configuration['AGENT_0']['DM'] and \
                        'class' in self.configuration['AGENT_0']['DM']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['DM']:
                        dm_args.update(
                            self.configuration['AGENT_0']['DM']['arguments']
                        )

                    dm_args.update(self.global_args)

                    self.dialogue_manager = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['DM'][
                                'package'],
                            self.configuration['AGENT_0']['DM'][
                                'class'],
                            dm_args
                        )

            # NLG Settings
            if 'NLG' in self.configuration['AGENT_0']:
                nlg_args = {}
                if 'package' in self.configuration['AGENT_0']['NLG'] and \
                        'class' in self.configuration['AGENT_0']['NLG']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['NLG']:
                        nlg_args = \
                            self.configuration['AGENT_0']['NLG']['arguments']

                    nlg_args.update(self.global_args)

                    self.nlg = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['NLG'][
                                'package'],
                            self.configuration['AGENT_0']['NLG'][
                                'class'],
                            nlg_args
                        )

                if self.nlg:
                    self.USE_NLG = True

        # True if at least one module is training
        self.IS_TRAINING = self.nlu and self.nlu.training or \
            self.dialogue_manager and self.dialogue_manager.training or \
            self.nlg and self.nlg.training

    def __del__(self):
        """
        Do some house-keeping: save the experience log (when SAVE_LOG is set)
        and the NLU, dialogue manager and NLG models, then drop the per-turn
        state.

        Attributes are looked up with getattr() so that an agent whose
        __init__ raised before they were assigned can still be finalized
        without an AttributeError being reported from __del__.

        :return: nothing
        """

        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', False):
            recorder.save()

        # Save order is NLU, dialogue manager, NLG; missing modules are skipped
        for module in (getattr(self, 'nlu', None),
                       getattr(self, 'dialogue_manager', None),
                       getattr(self, 'nlg', None)):
            if module:
                module.save()

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Reset the running dialogue statistics and the per-turn state, and
        initialize the loaded modules (NLU and NLG only when present; the
        dialogue manager unconditionally).

        :return: Nothing
        """

        # Running statistics (total_dialogue_turns is deliberately untouched)
        self.dialogue_episode = self.dialogue_turn = 0
        self.num_successful_dialogues = self.num_task_success = 0
        self.cumulative_rewards = 0

        # Each module gets its own fresh, empty argument dict
        if self.nlu:
            self.nlu.initialize({})
        self.dialogue_manager.initialize({})
        if self.nlg:
            self.nlg.initialize({})

        # Forget everything remembered from a previous dialogue
        for attr_name in ('curr_state', 'prev_state', 'prev_usr_utterance',
                          'prev_sys_utterance', 'prev_action', 'prev_reward',
                          'prev_success', 'prev_task_success'):
            setattr(self, attr_name, None)

    def start_dialogue(self, args=None):
        """
        Perform the initial dialogue turn.

        Re-initializes the user simulator (when one is used) and restarts the
        dialogue manager. If the system has the initiative, it opens with a
        welcomemsg() act. That act is printed (through the NLG, and spoken
        when USE_SPEECH is set), handed to the user simulator, rewarded and
        recorded. The per-turn bookkeeping is then reset, and the first user
        turn is handled by continue_dialogue().

        :param args: optional args (not used by this method)
        :return: nothing
        """

        self.dialogue_turn = 0
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            self.user_simulator.initialize(self.user_simulator_args)

            print('DEBUG > usr goal:')
            print(self.user_simulator.goal)

        self.dialogue_manager.restart({})

        if not self.USER_HAS_INITIATIVE:
            # The opening act is fixed rather than produced by the dialogue
            # manager.
            sys_response = [DialogueAct('welcomemsg', [])]

            if self.USE_NLG:
                sys_utterance = self.nlg.generate_output(
                    {'dacts': sys_response}
                )
                print('SYSTEM > %s ' % sys_utterance)

                if self.USE_SPEECH:
                    try:
                        tts = gTTS(sys_utterance)
                        tts.save('sys_output.mp3')
                        os.system('afplay sys_output.mp3')

                    except Exception as e:
                        print(
                            'WARNING: gTTS encountered an error: {0}. '
                            'Falling back to sys TTS.'.format(e)
                        )
                        # The utterance is free text: quote it so the shell
                        # neither fails on apostrophes ("I'm ...") nor
                        # interprets metacharacters it may contain.
                        import shlex
                        os.system('say ' + shlex.quote(sys_utterance))
            else:
                print(
                    'SYSTEM > %s ' % '; '.join(str(sr) for sr in sys_response)
                )

            if self.USE_USR_SIMULATOR:
                usim_input = sys_response

                if self.USER_SIMULATOR_NLU and self.USE_NLG:
                    usim_input = self.user_simulator.nlu.process_input(
                        sys_utterance
                    )

                self.user_simulator.receive_input(usim_input)
                rew, success, task_success = self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
            else:
                rew, success, task_success = 0, None, None

            # Pass independent snapshots for both state slots.
            # continue_dialogue() also records deep copies. Passing the live
            # dialogue-manager state would let later turns change what was
            # recorded for this one, if the recorder keeps the object as given.
            welcome_state = self.dialogue_manager.get_state()
            self.recorder.record(
                deepcopy(welcome_state),
                deepcopy(welcome_state),
                sys_response,
                rew,
                success,
                task_success,
                output_utterance=sys_utterance
            )

            self.dialogue_turn += 1

        # Re-initialize the per-turn bookkeeping for good measure
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.continue_dialogue()

    def continue_dialogue(self):
        """
        Perform the next dialogue turn.

        The turn proceeds in this order:

        - Obtain the user's input from the user simulator, the microphone
          (USE_SPEECH) or the keyboard.
        - Pass the input to the dialogue manager.
        - Produce the system response. A bye() act is forced once
          dialogue_turn reaches MAX_TURNS.
        - Render the response, and hand it to the user simulator if there
          is one.
        - Record the *previous* turn, whose resulting state is only known
          now.

        :return: nothing
        :raises EnvironmentError: if a human user types or speaks and no NLU
                                  module is loaded
        """

        # NOTE(review): usr_utterance is only filled in for human users, so
        # simulated turns are recorded with an empty input_utterance.
        usr_utterance = ''
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            usr_input = self.user_simulator.respond()

            # TODO: THIS FIRST IF WILL BE HANDLED BY ConversationalAgentGeneric
            #  -- SHOULD NOT LIVE HERE
            if isinstance(self.user_simulator, DTLUserSimulator):
                print('USER (nlg) > %s \n' % usr_input)
                usr_input = self.nlu.process_input(
                    usr_input,
                    self.dialogue_manager.get_state()
                )

            elif self.USER_SIMULATOR_NLG:
                print('USER > %s \n' % usr_input)

                if self.nlu:
                    # NOTE(review): unlike the other process_input() calls in
                    # this method, no dialogue state is passed here - confirm
                    # this is intended.
                    usr_input = self.nlu.process_input(usr_input)

                # Without an NLU, usr_input goes to the dialogue manager
                # exactly as the simulator returned it.

            else:
                print('USER (DACT) > %s \n' % '; '.join(
                    [str(ui) for ui in usr_input]))

        else:
            if self.USE_SPEECH:
                # Listen for input from the microphone
                with speech_rec.Microphone() as source:
                    print('(listening...)')
                    audio = self.asr.listen(source, phrase_time_limit=3)

                try:
                    # This uses the default key
                    usr_utterance = self.asr.recognize_google(audio)
                    print("Google ASR: " + usr_utterance)

                except speech_rec.UnknownValueError:
                    print("Google ASR did not understand you")

                except speech_rec.RequestError as e:
                    print("Google ASR request error: {0}".format(e))

            else:
                usr_utterance = input('USER > ')

            # Process the user's utterance
            if self.nlu:
                usr_input = self.nlu.process_input(
                    usr_utterance,
                    self.dialogue_manager.get_state()
                )
            else:
                raise EnvironmentError(
                    'agent: No nlu defined for '
                    'text-based interaction!'
                )

        self.dialogue_manager.receive_input(usr_input)

        # Snapshot the state the dialogue manager will decide from. It is the
        # "next state" of the previous turn's record, and it becomes
        # prev_state for the next one.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        if self.dialogue_turn < self.MAX_TURNS:
            sys_response = self.dialogue_manager.generate_output()

        else:
            # Force dialogue stop
            sys_response = [DialogueAct('bye', [])]

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output({'dacts': sys_response})
            print('SYSTEM > %s ' % sys_utterance)

            if self.USE_SPEECH:
                try:
                    tts = gTTS(text=sys_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')

                except Exception as e:
                    # Catch Exception rather than everything, so that
                    # KeyboardInterrupt / SystemExit still propagate.
                    print(
                        'WARNING: gTTS encountered an error: {0}. '
                        'Falling back to sys TTS.'.format(e)
                    )
                    # The utterance is free text: quote it so the shell
                    # neither fails on apostrophes ("I'm ...") nor
                    # interprets metacharacters it may contain.
                    import shlex
                    os.system('say ' + shlex.quote(sys_utterance))
        else:
            print('SYSTEM > %s ' % '; '.join([str(sr) for sr in sys_response]))

        if self.USE_USR_SIMULATOR:
            usim_input = sys_response

            if self.USER_SIMULATOR_NLU and self.USE_NLG:
                usim_input = \
                    self.user_simulator.nlu.process_input(sys_utterance)

            self.user_simulator.receive_input(usim_input)
            rew, success, task_success = \
                self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
        else:
            rew, success, task_success = 0, None, None

        # The previous turn can only be recorded now that its resulting state
        # (curr_state) is known; nothing is recorded on the first turn.
        if self.prev_state:
            self.recorder.record(
                self.prev_state,
                self.curr_state,
                self.prev_action,
                self.prev_reward,
                self.prev_success,
                input_utterance=self.prev_usr_utterance,
                output_utterance=self.prev_sys_utterance,
                task_success=self.prev_task_success
            )

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_action = deepcopy(sys_response)
        self.prev_usr_utterance = deepcopy(usr_utterance)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

    def end_dialogue(self):
        """
        Close the current dialogue.

        The final turn is recorded (forcing termination). The modules are
        trained when an episode boundary and enough logged dialogues line up,
        and saved every SAVE_INTERVAL episodes. Finally the reward, turn and
        success counters are updated from the last recorded turn.

        :return: nothing
        """

        # The final state serves as both the previous and the next state
        self.recorder.record(
            self.curr_state,
            self.curr_state,
            self.prev_action,
            self.prev_reward,
            self.prev_success,
            input_utterance=self.prev_usr_utterance,
            output_utterance=self.prev_sys_utterance,
            task_success=self.prev_task_success,
            force_terminate=True
        )

        self.dialogue_episode += 1

        training_due = (
            self.IS_TRAINING
            and self.dialogue_episode % self.train_interval == 0
            and len(self.recorder.dialogues) >= self.minibatch_length
        )

        if training_due:
            for epoch_no in range(1, self.train_epochs + 1):
                print('Training epoch {0} of {1}'.format(
                    epoch_no, self.train_epochs))

                # A fresh random minibatch of logged dialogues per epoch
                batch = random.sample(self.recorder.dialogues,
                                      self.minibatch_length)

                if self.nlu and self.nlu.training:
                    self.nlu.train(batch)

                if self.dialogue_manager.is_training():
                    self.dialogue_manager.train(batch)

                if self.nlg and self.nlg.training:
                    self.nlg.train(batch)

        # Statistics are read from the last turn of the last dialogue
        final_turn = self.recorder.dialogues[-1][-1]

        episode_reward = final_turn['cumulative_reward']
        self.cumulative_rewards += episode_reward
        print('CUMULATIVE REWARD: {0}'.format(episode_reward))

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % self.SAVE_INTERVAL == 0:
            for component in (self.nlu, self.dialogue_manager, self.nlg):
                if component:
                    component.save()

        subjective_success = final_turn['success']
        if not subjective_success:
            print('FAILURE (Subjective).')
        else:
            print('SUCCESS (Subjective)!')
            self.num_successful_dialogues += int(subjective_success)

        objective_success = final_turn['task_success']
        if objective_success:
            self.num_task_success += int(objective_success)

        print('OBJECTIVE TASK SUCCESS: {0}'.format(objective_success))

    def terminated(self):
        """
        Tell whether this agent's dialogue is over.

        :return: the dialogue manager's at_terminal_state() result when it is
                 truthy, otherwise whether dialogue_turn has exceeded
                 MAX_TURNS
        """

        reached_terminal = self.dialogue_manager.at_terminal_state()
        return reached_terminal or self.MAX_TURNS < self.dialogue_turn

    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        NOTE(review): this is a copy of ConversationalMultiAgent's constructor.
        It sits right after terminated() with no class statement in between,
        so it appears to belong to the preceding class and to replace that
        class's own __init__(self, configuration). It looks like a paste
        error; consider removing it.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        :raises ValueError: if the GENERAL section, its interaction_mode, the
            DIALOGUE section, this agent's section or its role is missing, or
            if a 'user' agent has no ontology and database to build goals from
        :raises FileNotFoundError: if the experience log to load, the
            ontology, the database or the goals file cannot be found
        """

        # Zero-argument super() binds to the class this method is actually
        # defined in; inside ConversationalMultiAgent it is identical to
        # super(ConversationalMultiAgent, self).
        super().__init__()

        self.agent_id = agent_id

        # Flag to alternate training between roles
        self.train_system = True

        # dialogue statistics
        self.dialogue_episode = 1
        # Note that since this is a multi-agent setting, dialogue_turn refers
        # to this agent's turns only
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 200
        self.train_interval = 50
        self.train_epochs = 3
        self.global_args = {}

        # Alternate training between the agents
        self.train_alternate_training = True
        # NOTE(review): this copies the *default* train_interval; a configured
        # 'train_interval' (read further down) does not affect it. Confirm
        # this is intended.
        self.train_switch_trainable_agents_every = self.train_interval

        self.configuration = configuration

        # True values here would imply some default modules
        self.USE_NLU = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 10

        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None
        self.goals_path = None

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        # The size of the experience pool is a hyperparameter.

        # Do not have an experience window larger than the current batch,
        # as past experience may not be relevant since both agents learn.
        self.recorder = DialogueEpisodeRecorder(size=20000)

        # TODO: Get reward type from the config
        self.reward_func = SlotFillingReward()

        if self.configuration:
            ag_id_str = 'AGENT_' + str(self.agent_id)

            # Error checks for options the config must have. .get() makes a
            # missing section raise the intended ValueError, not a KeyError.
            general_cfg = self.configuration.get('GENERAL')
            if not general_cfg:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not general_cfg.get('interaction_mode'):
                raise ValueError('Cannot run Plato without an interaction '
                                 'mode!')

            elif not self.configuration.get('DIALOGUE'):
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration.get(ag_id_str):
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            dialogue_cfg = self.configuration['DIALOGUE']
            agent_cfg = self.configuration[ag_id_str]

            # General settings
            if 'global_arguments' in general_cfg:
                self.global_args = general_cfg['global_arguments'] or {}

            if 'experience_logs' in general_cfg:
                logs_cfg = general_cfg['experience_logs'] or {}
                dialogues_path = logs_cfg.get('path')

                if logs_cfg.get('load'):
                    if dialogues_path and os.path.isfile(dialogues_path):
                        self.recorder.load(dialogues_path)
                    else:
                        raise FileNotFoundError(
                            'dialogue Log file %s not found (did you '
                            'provide one?)' % dialogues_path
                        )

                if 'save' in logs_cfg:
                    self.recorder.set_path(dialogues_path)
                    self.SAVE_LOG = bool(logs_cfg['save'])

            # Dialogue settings
            if 'initiative' in dialogue_cfg:
                self.USER_HAS_INITIATIVE = \
                    dialogue_cfg['initiative'] == 'user'

            if 'domain' in dialogue_cfg:
                self.domain = dialogue_cfg['domain']
            elif 'domain' in self.global_args:
                self.domain = self.global_args['domain']

            ontology_path = None
            if 'ontology_path' in dialogue_cfg:
                ontology_path = dialogue_cfg['ontology_path']
            elif 'ontology' in self.global_args:
                ontology_path = self.global_args['ontology']

            if ontology_path and os.path.isfile(ontology_path):
                self.ontology = ontology.Ontology(ontology_path)
            else:
                raise FileNotFoundError(
                    'domain file %s not found' % ontology_path
                )

            db_path = None
            if 'db_path' in dialogue_cfg:
                db_path = dialogue_cfg['db_path']
            elif 'database' in self.global_args:
                db_path = self.global_args['database']

            if db_path and os.path.isfile(db_path):
                if db_path[-3:] == '.db':
                    self.database = database.SQLDataBase(db_path)
                else:
                    self.database = database.DataBase(db_path)
            else:
                raise FileNotFoundError(
                    'Database file %s not found' % db_path
                )

            if 'goals_path' in dialogue_cfg:
                if os.path.isfile(dialogue_cfg['goals_path']):
                    self.goals_path = dialogue_cfg['goals_path']
                else:
                    raise FileNotFoundError(
                        'Goals file %s not '
                        'found' % dialogue_cfg['goals_path']
                    )

            # Agent Settings

            # Retrieve agent role
            if 'role' in agent_cfg:
                self.agent_role = agent_cfg['role']
            else:
                raise ValueError(
                    'agent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id)
                )

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator({
                        'ontology': self.ontology,
                        'database': self.database
                    })
                else:
                    raise ValueError(
                        'Conversational Multi Agent (user): Cannot generate '
                        'goal without ontology and database.'
                    )

            # Retrieve agent parameters
            if 'max_turns' in agent_cfg:
                self.MAX_TURNS = agent_cfg['max_turns']

            if 'train_interval' in agent_cfg:
                self.train_interval = agent_cfg['train_interval']

            if 'train_minibatch' in agent_cfg:
                self.minibatch_length = agent_cfg['train_minibatch']

            if 'train_epochs' in agent_cfg:
                self.train_epochs = agent_cfg['train_epochs']

            if 'save_interval' in agent_cfg:
                self.SAVE_INTERVAL = agent_cfg['save_interval']

            def load_component(section, base_args=None):
                """
                Load the module described by agent_cfg[section].

                Argument precedence, lowest to highest: base_args, then the
                section's 'arguments', then the global arguments. The
                configured 'arguments' dict is copied rather than updated in
                place, so the (possibly shared) configuration is left intact.

                :return: the loaded module, or None when the section does not
                         name both a 'package' and a 'class'
                """
                component_cfg = agent_cfg[section] or {}
                if 'package' not in component_cfg or \
                        'class' not in component_cfg:
                    return None

                module_args = dict(base_args or {})
                module_args.update(component_cfg.get('arguments') or {})
                module_args.update(self.global_args)

                return ConversationalGenericAgent.load_module(
                    component_cfg['package'],
                    component_cfg['class'],
                    module_args
                )

            # NLU Settings
            if 'NLU' in agent_cfg:
                self.nlu = load_component('NLU')

            # DM Settings
            if 'DM' in agent_cfg:
                # NOTE(review): global arguments override these entries too
                # (e.g. a global 'ontology' path replaces the Ontology object
                # built above) - confirm this is intended.
                self.dialogue_manager = load_component(
                    'DM',
                    {
                        'settings': self.configuration,
                        'ontology': self.ontology,
                        'database': self.database,
                        'domain': self.domain,
                        'agent_id': self.agent_id,
                        'agent_role': self.agent_role
                    }
                )

            # NLG Settings
            if 'NLG' in agent_cfg:
                self.nlg = load_component('NLG')

                if self.nlg:
                    self.USE_NLG = True

        # True if at least one module is training
        self.IS_TRAINING = self.nlu and self.nlu.training or \
            self.dialogue_manager and self.dialogue_manager.training or \
            self.nlg and self.nlg.training
class ConversationalMultiAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - dialogue Acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd?

    - parser

    In the multi-agent setting each instance plays a single role ('system' or
    'user'), read from the AGENT_<agent_id> section of the configuration, and
    may load its own NLU, DM and NLG modules. When training is enabled the
    agents can alternate which role is trained (see
    ``train_alternate_training``).
    """

    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id; selects the
                         'AGENT_<agent_id>' section of the configuration
        :raises ValueError: if the GENERAL, DIALOGUE or AGENT_<agent_id>
                            sections (or GENERAL/interaction_mode) are missing
                            or empty, if the agent has no role, or if a user
                            agent has no ontology and database to build goals
        :raises FileNotFoundError: if a configured experience log (with load
                                   enabled), ontology, database or goals file
                                   does not exist
        """

        super(ConversationalMultiAgent, self).__init__()

        self.agent_id = agent_id

        # Flag to alternate training between roles
        self.train_system = True

        # dialogue statistics
        self.dialogue_episode = 1
        # Note that since this is a multi-agent setting, dialogue_turn refers
        # to this agent's turns only
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Default meta-parameter values (may be overridden by the config)
        self.minibatch_length = 200
        self.train_interval = 50
        self.train_epochs = 3
        self.global_args = {}

        # Alternate training between the agents. The switch interval is
        # re-synchronised with train_interval after the config is parsed.
        self.train_alternate_training = True
        self.train_switch_trainable_agents_every = self.train_interval

        self.configuration = configuration

        # True values here would imply some default modules
        self.USE_NLU = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 10

        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None
        self.goals_path = None

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        # The size of the experience pool is a hyperparameter.

        # Do not have an experience window larger than the current batch,
        # as past experience may not be relevant since both agents learn.
        self.recorder = DialogueEpisodeRecorder(size=20000)

        # TODO: Get reward type from the config
        self.reward_func = SlotFillingReward()

        if self.configuration:
            ag_id_str = 'AGENT_' + str(self.agent_id)

            # Error checks for options the config must have. .get() is used
            # so that a missing section raises the intended ValueError rather
            # than a bare KeyError.
            general_cfg = self.configuration.get('GENERAL')
            if not general_cfg:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            if not general_cfg.get('interaction_mode'):
                raise ValueError('Cannot run Plato without an interaction '
                                 'mode!')

            if not self.configuration.get('DIALOGUE'):
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            if not self.configuration.get(ag_id_str):
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            dialogue_cfg = self.configuration['DIALOGUE']
            agent_cfg = self.configuration[ag_id_str]

            # General settings
            if 'global_arguments' in general_cfg:
                self.global_args = general_cfg['global_arguments']

            if 'experience_logs' in general_cfg:
                logs_cfg = general_cfg['experience_logs']
                dialogues_path = None
                if 'path' in logs_cfg:
                    dialogues_path = logs_cfg['path']

                if 'load' in logs_cfg and bool(logs_cfg['load']):
                    if dialogues_path and os.path.isfile(dialogues_path):
                        self.recorder.load(dialogues_path)
                    else:
                        raise FileNotFoundError(
                            'dialogue Log file %s not found (did you '
                            'provide one?)' % dialogues_path
                        )

                if 'save' in logs_cfg:
                    self.recorder.set_path(dialogues_path)
                    self.SAVE_LOG = bool(logs_cfg['save'])

            # Dialogue settings
            if 'initiative' in dialogue_cfg:
                self.USER_HAS_INITIATIVE = bool(
                    dialogue_cfg['initiative'] == 'user'
                )

            if 'domain' in dialogue_cfg:
                self.domain = dialogue_cfg['domain']
            elif 'domain' in self.global_args:
                self.domain = self.global_args['domain']

            ontology_path = None
            if 'ontology_path' in dialogue_cfg:
                ontology_path = dialogue_cfg['ontology_path']
            elif 'ontology' in self.global_args:
                ontology_path = self.global_args['ontology']

            if ontology_path and os.path.isfile(ontology_path):
                self.ontology = ontology.Ontology(ontology_path)
            else:
                raise FileNotFoundError(
                    'domain file %s not found' % ontology_path
                )

            db_path = None
            if 'db_path' in dialogue_cfg:
                db_path = dialogue_cfg['db_path']
            elif 'database' in self.global_args:
                db_path = self.global_args['database']

            if db_path and os.path.isfile(db_path):
                # '.db' files are handled by the SQL-backed database class
                if db_path[-3:] == '.db':
                    self.database = database.SQLDataBase(db_path)
                else:
                    self.database = database.DataBase(db_path)
            else:
                raise FileNotFoundError(
                    'Database file %s not found' % db_path
                )

            if 'goals_path' in dialogue_cfg:
                if os.path.isfile(dialogue_cfg['goals_path']):
                    self.goals_path = dialogue_cfg['goals_path']
                else:
                    raise FileNotFoundError(
                        'Goals file %s not '
                        'found' % dialogue_cfg['goals_path']
                    )

            # Agent Settings

            # Retrieve agent role
            if 'role' in agent_cfg:
                self.agent_role = agent_cfg['role']
            else:
                raise ValueError(
                    'agent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id)
                )

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator({
                        'ontology': self.ontology,
                        'database': self.database
                    })
                else:
                    raise ValueError(
                        'Conversational Multi Agent (user): Cannot generate '
                        'goal without ontology and database.'
                    )

            # Retrieve agent parameters
            if 'max_turns' in agent_cfg:
                self.MAX_TURNS = agent_cfg['max_turns']

            if 'train_interval' in agent_cfg:
                self.train_interval = agent_cfg['train_interval']

            if 'train_minibatch' in agent_cfg:
                self.minibatch_length = agent_cfg['train_minibatch']

            if 'train_epochs' in agent_cfg:
                self.train_epochs = agent_cfg['train_epochs']

            if 'save_interval' in agent_cfg:
                self.SAVE_INTERVAL = agent_cfg['save_interval']

            # Keep the role-switching cadence aligned with the (possibly
            # overridden) training interval. If they differ (e.g. train every
            # 100 episodes, switch every 50), the flag flips an even/odd
            # number of times between training rounds and alternate training
            # ends up always training the same role.
            self.train_switch_trainable_agents_every = self.train_interval

            # NLU Settings
            if 'NLU' in agent_cfg:
                nlu_cfg = agent_cfg['NLU']
                if 'package' in nlu_cfg and 'class' in nlu_cfg:
                    # Copy, so that merging the global arguments below does
                    # not mutate the configuration dictionary itself.
                    nlu_args = dict(nlu_cfg.get('arguments') or {})
                    nlu_args.update(self.global_args)

                    self.nlu = ConversationalGenericAgent.load_module(
                        nlu_cfg['package'],
                        nlu_cfg['class'],
                        nlu_args
                    )

                if self.nlu:
                    self.USE_NLU = True

            # DM Settings
            if 'DM' in agent_cfg:
                dm_cfg = agent_cfg['DM']
                dm_args = {
                    'settings': self.configuration,
                    'ontology': self.ontology,
                    'database': self.database,
                    'domain': self.domain,
                    'agent_id': self.agent_id,
                    'agent_role': self.agent_role,
                }

                if 'package' in dm_cfg and 'class' in dm_cfg:
                    # Precedence: defaults < DM arguments < global arguments
                    dm_args.update(dm_cfg.get('arguments') or {})
                    dm_args.update(self.global_args)

                    self.dialogue_manager = \
                        ConversationalGenericAgent.load_module(
                            dm_cfg['package'],
                            dm_cfg['class'],
                            dm_args
                        )

            # NLG Settings
            if 'NLG' in agent_cfg:
                nlg_cfg = agent_cfg['NLG']
                if 'package' in nlg_cfg and 'class' in nlg_cfg:
                    # Copy, so that merging the global arguments below does
                    # not mutate the configuration dictionary itself.
                    nlg_args = dict(nlg_cfg.get('arguments') or {})
                    nlg_args.update(self.global_args)

                    self.nlg = ConversationalGenericAgent.load_module(
                        nlg_cfg['package'],
                        nlg_cfg['class'],
                        nlg_args
                    )

                if self.nlg:
                    self.USE_NLG = True

        # True if at least one module is training
        self.IS_TRAINING = bool(
            self.nlu and self.nlu.training or
            self.dialogue_manager and self.dialogue_manager.training or
            self.nlg and self.nlg.training
        )

    def __del__(self):
        """
        Do some house-keeping and save the models.

        Attributes are read with getattr() because __del__ may also run on an
        instance whose __init__ raised before setting them.

        :return: nothing
        """

        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', False):
            recorder.save()

        for module in (getattr(self, 'nlu', None),
                       getattr(self, 'dialogue_manager', None),
                       getattr(self, 'nlg', None)):
            if module:
                module.save()

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Initializes the conversational agent based on settings in the
        configuration file.

        Resets the dialogue statistics and the per-turn bookkeeping, and
        initializes each loaded module. A dialogue manager is required.

        :return: Nothing
        """

        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0

        if self.nlu:
            self.nlu.initialize({})

        self.dialogue_manager.initialize({})

        if self.nlg:
            self.nlg.initialize({})

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def start_dialogue(self, args=None):
        """
        Perform initial dialogue turn.

        A user agent samples a new goal from its goal generator. Any other
        role must be given the goal via ``args['goal']``. The dialogue
        manager is restarted with that goal. The system role then emits a
        welcome message, as text if NLG is enabled and as dialogue acts
        otherwise.

        :param args: optional arguments; a dictionary that may hold 'goal'
        :return: a dictionary with keys 'input_utterance' (always None),
                 'output_raw', 'output_dacts' and 'goal'
        :raises ValueError: if a non-user agent is not given a goal
        """

        self.dialogue_turn = 0

        if self.agent_role == 'user':
            self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.update_goal(self.agent_goal)

            print('DEBUG > usr goal:')
            for c in self.agent_goal.constraints:
                print(f'\t\tConstr({self.agent_goal.constraints[c].slot}='
                      f'{self.agent_goal.constraints[c].value})')
            print('\t\t-------------')
            for r in self.agent_goal.requests:
                print(f'\t\tReq({self.agent_goal.requests[r].slot})')
            print('\n')

        elif args and 'goal' in args:
            # No deep copy here so that all agents see the same goal.
            self.agent_goal = args['goal']
        else:
            raise ValueError(
                'ConversationalMultiAgent - no goal provided '
                'for agent {0}!'.format(self.agent_role)
            )

        self.dialogue_manager.restart({'goal': self.agent_goal})

        # This forces the DM's initial utterance to be the welcomemsg act.
        # The code below can be used to push an empty list of DialogueActs to
        # the DM and get a response from the model.
        # self.dialogue_manager.receive_input([])
        # response = self.dialogue_manager.respond()
        response = [DialogueAct('welcomemsg', [])]
        response_utterance = ''

        if self.agent_role == 'system':
            if self.USE_NLG:
                response_utterance = self.nlg.generate_output(
                    {
                        'dacts': response,
                        'system': self.agent_role == 'system',
                        'last_sys_utterance': ''
                    }
                )

                print(
                    '{0} > {1}'.format(
                        self.agent_role.upper(),
                        response_utterance
                    )
                )

                if self.USE_SPEECH:
                    tts = gTTS(text=response_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')
            else:
                print(
                    '{0} > {1}'.format(
                        self.agent_role.upper(),
                        '; '.join([str(sr) for sr in response])
                    )
                )

        # TODO: Generate output depending on initiative - i.e.
        # have users also start the dialogue

        # Re-initialize these for good measure
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        return {'input_utterance': None,
                'output_raw': response_utterance,
                'output_dacts': response,
                'goal': self.agent_goal
                }

    def continue_dialogue(self, args):
        """
        Perform next dialogue turn.

        The other agent's input is processed as follows:
        - if an NLU is loaded and the raw input is a string, it goes through
          the NLU;
        - otherwise the provided dialogue acts are used, if any;
        - failing both, the raw input is used as-is.

        The processed input is fed to the dialogue manager, and a response is
        generated (or 'bye' once MAX_TURNS is reached) and rewarded. The
        previous turn's experience is recorded once the next state is known.

        :param args: a dictionary with 'other_input_raw' (required) and,
                     optionally, 'other_input_dacts' and 'goal'
        :return: this agent's output: a dictionary with keys
                 'input_utterance' (always None), 'output_raw',
                 'output_dacts' and 'goal'
        :raises ValueError: if 'other_input_raw' is not in args
        """

        if 'other_input_raw' not in args:
            raise ValueError(
                'ConversationalMultiAgent called without raw input!'
            )

        other_input_raw = args['other_input_raw']
        other_input_dacts = args.get('other_input_dacts')

        goal = args.get('goal')
        if goal:
            self.agent_goal = goal

        sys_utterance = ''

        other_input_nlu = deepcopy(other_input_raw)

        if self.nlu and isinstance(other_input_raw, str):
            # Process the other agent's utterance
            other_input_nlu = self.nlu.process_input(
                other_input_raw,
                self.dialogue_manager.get_state()
            )

        elif other_input_dacts:
            # If no utterance provided, use the dacts
            other_input_nlu = other_input_dacts

        self.dialogue_manager.receive_input(other_input_nlu)

        # Keep track of prev_state, for the DialogueEpisodeRecorder
        # Store here because this is the state that the dialogue
        # manager will use to make a decision.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        # Update goal's ground truth
        if self.agent_role == 'system':
            self.agent_goal.ground_truth = deepcopy(
                self.curr_state.item_in_focus
            )

        if self.dialogue_turn < self.MAX_TURNS:
            response = self.dialogue_manager.generate_output()
            self.agent_goal = self.dialogue_manager.DSTracker.DState.user_goal

        else:
            # Force dialogue stop
            response = [DialogueAct('bye', [])]

        rew, success, task_success = self.reward_func.calculate(
            self.dialogue_manager.get_state(),
            response,
            goal=self.agent_goal,
            agent_role=self.agent_role
        )

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output(
                {
                    'dacts': response,
                    'system': self.agent_role == 'system',
                    'last_sys_utterance': other_input_raw
                }
            ) + ' '

            print('{0} > {1}'.format(self.agent_role.upper(), sys_utterance))

            if self.USE_SPEECH:
                tts = gTTS(text=sys_utterance, lang='en')
                tts.save('sys_output.mp3')
                os.system('afplay sys_output.mp3')
        else:
            print(
                '{0} > {1} \n'.format(
                    self.agent_role.upper(),
                    '; '.join([str(sr) for sr in response])
                )
            )

        # Record the previous turn now that its successor state is known
        if self.prev_state:
            self.recorder.record(
                self.prev_state,
                self.curr_state,
                self.prev_action,
                self.prev_reward,
                self.prev_success,
                input_utterance=self.prev_usr_utterance,
                output_utterance=self.prev_sys_utterance,
                task_success=self.prev_task_success
            )

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_usr_utterance = deepcopy(other_input_raw)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_action = deepcopy(response)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

        return {'input_utterance': None,
                'output_raw': sys_utterance,
                'output_dacts': response,
                'goal': self.agent_goal
                }

    def end_dialogue(self):
        """
        Perform final dialogue turn. Train and save models if applicable.

        Records the final (terminal) state of the dialogue and updates the
        dialogue statistics. If training, it trains this agent's modules every
        ``train_interval`` episodes; under alternate training this happens
        only when it is this role's turn. It saves the modules every
        ``SAVE_INTERVAL`` episodes.

        If there is no current state, a warning is printed and nothing is
        recorded.

        :return: nothing
        """

        # Switch which role is trainable every N episodes
        if self.dialogue_episode % \
                self.train_switch_trainable_agents_every == 0:
            self.train_system = not self.train_system

        # Record final state. This must happen for every dialogue (not only
        # on role-switch episodes): the statistics below read the last
        # recorded dialogue.
        if self.curr_state:
            if not self.curr_state.is_terminal_state:
                self.curr_state.is_terminal_state = True
                self.prev_reward, self.prev_success, self.prev_task_success = \
                    self.reward_func.calculate(
                        self.curr_state,
                        [DialogueAct('bye', [])],
                        goal=self.agent_goal,
                        agent_role=self.agent_role
                    )
        else:
            print(
                'Warning! Conversational Multi Agent attempted to end the '
                'dialogue with no state. This dialogue will NOT be saved.')
            return

        self.recorder.record(
            self.curr_state,
            self.curr_state,
            self.prev_action,
            self.prev_reward,
            self.prev_success,
            input_utterance=self.prev_usr_utterance,
            output_utterance=self.prev_sys_utterance,
            task_success=self.prev_task_success,
            role=self.agent_role,
            force_terminate=True
        )

        self.dialogue_episode += 1

        if self.IS_TRAINING:
            # Train when not alternating, or when it is this role's turn
            trains_this_role = (
                not self.train_alternate_training
                or (self.train_system and self.agent_role == 'system')
                or (not self.train_system and self.agent_role == 'user')
            )

            if trains_this_role and \
                    self.dialogue_episode % self.train_interval == 0 and \
                    len(self.recorder.dialogues) >= self.minibatch_length:
                for epoch in range(self.train_epochs):
                    print(
                        '{0}: Training epoch {1} of {2}'.format(
                            self.agent_role,
                            (epoch + 1),
                            self.train_epochs
                        )
                    )

                    # Sample a fresh minibatch for every epoch
                    minibatch = random.sample(
                        self.recorder.dialogues,
                        self.minibatch_length
                    )

                    if self.nlu:
                        self.nlu.train(minibatch)

                    self.dialogue_manager.train(minibatch)

                    if self.nlg:
                        self.nlg.train(minibatch)

        last_dialogue = self.recorder.dialogues[-1]
        self.cumulative_rewards += last_dialogue[-1]['cumulative_reward']

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % self.SAVE_INTERVAL == 0:
            if self.nlu:
                self.nlu.save()

            if self.dialogue_manager:
                self.dialogue_manager.save()

            if self.nlg:
                self.nlg.save()

        # Count successful dialogues
        dialogue_reward = sum(t['reward'] for t in last_dialogue)
        if last_dialogue[-1]['success']:
            print(
                '{0} SUCCESS! (reward: {1})'.format(
                    self.agent_role,
                    dialogue_reward
                )
            )
            self.num_successful_dialogues += \
                int(last_dialogue[-1]['success'])

        else:
            print(
                '{0} FAILURE. (reward: {1})'.format(
                    self.agent_role,
                    dialogue_reward
                )
            )

        if last_dialogue[-1]['task_success']:
            self.num_task_success += int(last_dialogue[-1]['task_success'])

    def terminated(self):
        """
        Check if this agent is at a terminal state.

        :return: True or False
        """

        # Hard coded response to bye to enforce dialogue_policy according to
        # which if any agent issues a 'bye' then the dialogue
        # terminates. Otherwise in multi-agent settings it is very hard to
        # learn the association and learn to terminate
        # the dialogue.
        user_acts = self.dialogue_manager.get_state().user_acts
        if user_acts and any(act.intent == 'bye' for act in user_acts):
            return True

        return self.dialogue_manager.at_terminal_state()

    def get_state(self):
        """
        Get this agent's state

        :return: a DialogueState
        """
        return self.dialogue_manager.get_state()

    def get_goal(self):
        """
        Get this agent's goal

        :return: a Goal
        """
        return self.agent_goal

    def set_goal(self, goal):
        """
        Set or update this agent's goal

        :param goal: a Goal
        :return: nothing
        """

        # Note: reason for non-deep copy is that if this agent changes the goal
        #       these changes are propagated to e.g. the reward function, and
        #       the reward calculation is up to date.
        self.agent_goal = goal
示例#5
0
    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        """
        self.agent_id = agent_id

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 250
        self.train_interval = 50
        self.train_epochs = 10

        self.configuration = configuration

        self.recorder = DialogueEpisodeRecorder()

        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 1000
        self.MAX_TURNS = 15
        self.INTERACTION_MODE = 'simulation'
        self.USE_GUI = False

        # This indicates which module controls the state so that we can query
        # it for dialogue termination (e.g. at end_dialogue)
        self.STATEFUL_MODULE = -1

        self.reward_func = SlotFillingGoalAdvancementReward()

        self.ConversationalModules = []
        self.prev_m_out = ConversationalFrame({})

        self.goal_generator = None
        self.agent_goal = None
        self.agent_role = ''

        ag_id_str = 'AGENT_' + str(agent_id)

        if self.configuration:
            if 'GENERAL' not in self.configuration:
                raise ValueError('No GENERAL section in config!')
            if 'AGENT_' + str(agent_id) not in self.configuration:
                raise ValueError(f'NO AGENT_{agent_id} section in config!')

            if 'role' in self.configuration[ag_id_str]:
                self.agent_role = self.configuration[ag_id_str]['role']

            # Retrieve agent parameters
            if 'max_turns' in self.configuration[ag_id_str]:
                self.MAX_TURNS = self.configuration[ag_id_str]['max_turns']

            if 'train_interval' in self.configuration[ag_id_str]:
                self.train_interval = \
                    self.configuration[ag_id_str]['train_interval']

            if 'train_minibatch' in self.configuration[ag_id_str]:
                self.minibatch_length = \
                    self.configuration[ag_id_str]['train_minibatch']

            if 'train_epochs' in self.configuration[ag_id_str]:
                self.train_epochs = \
                    self.configuration[ag_id_str]['train_epochs']

            if 'save_interval' in self.configuration[ag_id_str]:
                self.SAVE_INTERVAL = \
                    self.configuration[ag_id_str]['save_interval']

            if 'interaction_mode' in self.configuration['GENERAL']:
                self.INTERACTION_MODE = \
                    self.configuration['GENERAL']['interaction_mode']

            if 'use_gui' in self.configuration['GENERAL']:
                self.USE_GUI = self.configuration['GENERAL']['use_gui']

            if 'experience_logs' in self.configuration['GENERAL']:
                dialogues_path = None
                if 'path' in self.configuration['GENERAL']['experience_logs']:
                    dialogues_path = \
                        self.configuration['GENERAL'][
                            'experience_logs']['path']

                if 'load' in self.configuration['GENERAL']['experience_logs'] \
                        and bool(self.configuration['GENERAL']
                                 ['experience_logs']['load']
                                 ):
                    if dialogues_path and os.path.isfile(dialogues_path):
                        self.recorder.load(dialogues_path)
                    else:
                        raise FileNotFoundError(
                            'dialogue Log file %s not found (did you '
                            'provide one?)' % dialogues_path)

                if 'save' in self.configuration['GENERAL']['experience_logs']:
                    self.recorder.set_path(dialogues_path)
                    self.SAVE_LOG = bool(self.configuration['GENERAL']
                                         ['experience_logs']['save'])

            self.NModules = 0
            if 'modules' in self.configuration[ag_id_str]:
                self.NModules = int(self.configuration[ag_id_str]['modules'])

            if 'stateful_module' in self.configuration[ag_id_str]:
                self.STATEFUL_MODULE = int(
                    self.configuration[ag_id_str]['stateful_module'])

            # Note: Since we pass settings as a default argument, any
            #       module can access the global args. However, we
            #       add it here too for ease of use.
            self.global_arguments = {'settings': deepcopy(self.configuration)}
            if 'global_arguments' in self.configuration['GENERAL']:
                self.global_arguments.update(
                    self.configuration['GENERAL']['global_arguments'])

            # Load the goal generator, if any
            if 'GOAL_GENERATOR' in self.configuration[ag_id_str]:
                if 'package' not in \
                        self.configuration[ag_id_str]['GOAL_GENERATOR']:
                    raise ValueError(f'No package path provided for '
                                     f'goal generator!')
                elif 'class' not in \
                        self.configuration[ag_id_str]['GOAL_GENERATOR']:
                    raise ValueError(f'No class name provided for '
                                     f'goal generator!')
                else:
                    self.goal_generator = self.load_module(
                        self.configuration[ag_id_str]['GOAL_GENERATOR']
                        ['package'], self.configuration[ag_id_str]
                        ['GOAL_GENERATOR']['class'], self.global_arguments)

            # Load the modules
            for m in range(self.NModules):
                if 'MODULE_'+str(m) not in \
                        self.configuration[ag_id_str]:
                    raise ValueError(f'No MODULE_{m} section in config!')

                if 'parallel_modules' in self.configuration[ag_id_str][
                        'MODULE_' + str(m)]:

                    n_parallel_modules = self.configuration[ag_id_str][
                        'MODULE_' + str(m)]['parallel_modules']

                    parallel_modules = []

                    for pm in range(n_parallel_modules):
                        if 'package' not in self.configuration[ag_id_str][
                                'MODULE_' + str(m)]['PARALLEL_MODULE_' +
                                                    str(pm)]:
                            raise ValueError(
                                f'No arguments provided for parallel module '
                                f'{pm} of module {m}!')

                        package = self.configuration[ag_id_str][
                            'MODULE_' + str(m)]['PARALLEL_MODULE_' +
                                                str(pm)]['package']

                        if 'class' not in self.configuration[ag_id_str][
                                'MODULE_' + str(m)]['PARALLEL_MODULE_' +
                                                    str(pm)]:
                            raise ValueError(
                                f'No arguments provided for parallel module '
                                f'{pm} of module {m}!')

                        klass = self.configuration[ag_id_str][
                            'MODULE_' + str(m)]['PARALLEL_MODULE_' +
                                                str(pm)]['class']

                        # Append global arguments
                        # (add configuration by default)
                        args = deepcopy(self.global_arguments)
                        if 'arguments' in \
                                self.configuration[
                                    ag_id_str
                                ]['MODULE_' + str(m)][
                                    'PARALLEL_MODULE_' + str(pm)]:
                            args.update(self.configuration[ag_id_str][
                                'MODULE_' + str(m)]['PARALLEL_MODULE_' +
                                                    str(pm)]['arguments'])

                        parallel_modules.append(
                            self.load_module(package, klass, args))

                    self.ConversationalModules.append(parallel_modules)

                else:
                    if 'package' not in self.configuration[ag_id_str]['MODULE_'
                                                                      +
                                                                      str(m)]:
                        raise ValueError(f'No arguments provided for module '
                                         f'{m}!')

                    package = self.configuration[ag_id_str]['MODULE_' +
                                                            str(m)]['package']

                    if 'class' not in self.configuration[ag_id_str]['MODULE_' +
                                                                    str(m)]:
                        raise ValueError(f'No arguments provided for module '
                                         f'{m}!')

                    klass = self.configuration[ag_id_str]['MODULE_' +
                                                          str(m)]['class']

                    # Append global arguments (add configuration by default)
                    args = deepcopy(self.global_arguments)
                    if 'arguments' in \
                            self.configuration[
                                ag_id_str
                            ]['MODULE_' + str(m)]:
                        args.update(self.configuration[
                            'AGENT_' + str(agent_id)]['MODULE_' +
                                                      str(m)]['arguments'])

                    self.ConversationalModules.append(
                        self.load_module(package, klass, args))

        else:
            raise AttributeError('ConversationalGenericAgent: '
                                 'No settings (config) provided!')

        # TODO: Parse config modules I/O and raise error if
        #       any inconsistencies found

        # Initialize automatic speech recognizer, if necessary
        self.asr = None
        if self.INTERACTION_MODE == 'speech' and not self.USE_GUI:
            self.asr = speech_rec.Recognizer()
示例#6
0
class ConversationalGenericAgent(ConversationalAgent):
    """
    The ConversationalGenericAgent receives a list of modules in
    its configuration file, that are chained together serially -
    i.e. the input to the agent is passed to the first module,
    the first module's output is passed as input to the second
    module and so on. Modules are wrapped using ConversationalModules.
    The input and output passed between modules is wrapped into
    ConversationalFrames.
    """
    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        Reads the GENERAL section and this agent's AGENT_<agent_id> section
        of the configuration, optionally loads experience logs and a goal
        generator, and instantiates the chain of ConversationalModules. A
        MODULE_<m> entry that declares 'parallel_modules' is turned into a
        list of sub-modules built from its PARALLEL_MODULE_<pm> entries.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        :raises AttributeError: if no configuration is provided
        :raises ValueError: if a required config section or key is missing
        :raises FileNotFoundError: if experience logs must be loaded but the
                                   log file does not exist
        """
        self.agent_id = agent_id

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Default training meta-parameters (may be overridden by the config)
        self.minibatch_length = 250
        self.train_interval = 50
        self.train_epochs = 10

        self.configuration = configuration

        self.recorder = DialogueEpisodeRecorder()

        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 1000
        self.MAX_TURNS = 15
        self.INTERACTION_MODE = 'simulation'
        self.USE_GUI = False

        # This indicates which module controls the state so that we can query
        # it for dialogue termination (e.g. at end_dialogue)
        self.STATEFUL_MODULE = -1

        self.reward_func = SlotFillingGoalAdvancementReward()

        self.ConversationalModules = []
        self.prev_m_out = ConversationalFrame({})

        self.goal_generator = None
        self.agent_goal = None
        self.agent_role = ''

        ag_id_str = 'AGENT_' + str(agent_id)

        if not self.configuration:
            raise AttributeError('ConversationalGenericAgent: '
                                 'No settings (config) provided!')

        if 'GENERAL' not in self.configuration:
            raise ValueError('No GENERAL section in config!')
        if ag_id_str not in self.configuration:
            raise ValueError(f'NO AGENT_{agent_id} section in config!')

        general_cfg = self.configuration['GENERAL']
        agent_cfg = self.configuration[ag_id_str]

        if 'role' in agent_cfg:
            self.agent_role = agent_cfg['role']

        # Retrieve agent parameters (keep the defaults above when absent)
        self.MAX_TURNS = agent_cfg.get('max_turns', self.MAX_TURNS)
        self.train_interval = agent_cfg.get('train_interval',
                                            self.train_interval)
        self.minibatch_length = agent_cfg.get('train_minibatch',
                                              self.minibatch_length)
        self.train_epochs = agent_cfg.get('train_epochs', self.train_epochs)
        self.SAVE_INTERVAL = agent_cfg.get('save_interval',
                                           self.SAVE_INTERVAL)
        self.INTERACTION_MODE = general_cfg.get('interaction_mode',
                                                self.INTERACTION_MODE)
        self.USE_GUI = general_cfg.get('use_gui', self.USE_GUI)

        if 'experience_logs' in general_cfg:
            logs_cfg = general_cfg['experience_logs']
            dialogues_path = logs_cfg.get('path')

            if 'load' in logs_cfg and bool(logs_cfg['load']):
                if dialogues_path and os.path.isfile(dialogues_path):
                    self.recorder.load(dialogues_path)
                else:
                    raise FileNotFoundError(
                        'dialogue Log file %s not found (did you '
                        'provide one?)' % dialogues_path)

            if 'save' in logs_cfg:
                self.recorder.set_path(dialogues_path)
                self.SAVE_LOG = bool(logs_cfg['save'])

        self.NModules = int(agent_cfg.get('modules', 0))

        if 'stateful_module' in agent_cfg:
            self.STATEFUL_MODULE = int(agent_cfg['stateful_module'])

        # Note: Since we pass settings as a default argument, any
        #       module can access the global args. However, we
        #       add it here too for ease of use.
        self.global_arguments = {'settings': deepcopy(self.configuration)}
        if 'global_arguments' in general_cfg:
            self.global_arguments.update(general_cfg['global_arguments'])

        # Load the goal generator, if any
        if 'GOAL_GENERATOR' in agent_cfg:
            goal_cfg = agent_cfg['GOAL_GENERATOR']
            if 'package' not in goal_cfg:
                raise ValueError('No package path provided for '
                                 'goal generator!')
            if 'class' not in goal_cfg:
                raise ValueError('No class name provided for '
                                 'goal generator!')
            self.goal_generator = self.load_module(
                goal_cfg['package'], goal_cfg['class'],
                self.global_arguments)

        # Load the modules, in chain order
        for m in range(self.NModules):
            module_key = f'MODULE_{m}'
            if module_key not in agent_cfg:
                raise ValueError(f'No MODULE_{m} section in config!')
            module_cfg = agent_cfg[module_key]

            if 'parallel_modules' in module_cfg:
                parallel_modules = []
                # int() for consistency with 'modules' / 'stateful_module'
                for pm in range(int(module_cfg['parallel_modules'])):
                    pm_key = f'PARALLEL_MODULE_{pm}'
                    if pm_key not in module_cfg:
                        raise ValueError(
                            f'No PARALLEL_MODULE_{pm} section in '
                            f'MODULE_{m} config!')
                    parallel_modules.append(self._load_configured_module(
                        module_cfg[pm_key],
                        f'parallel module {pm} of module {m}'))

                self.ConversationalModules.append(parallel_modules)

            else:
                self.ConversationalModules.append(
                    self._load_configured_module(module_cfg, f'module {m}'))

        # TODO: Parse config modules I/O and raise error if
        #       any inconsistencies found

        # Initialize automatic speech recognizer, if necessary
        self.asr = None
        if self.INTERACTION_MODE == 'speech' and not self.USE_GUI:
            self.asr = speech_rec.Recognizer()

    def _load_configured_module(self, module_cfg, description):
        """
        Validate a module config section and instantiate the module.

        :param module_cfg: dict with 'package', 'class' and, optionally,
                           'arguments' (merged over the global arguments)
        :param description: human-readable module name for error messages
        :return: the instantiated module
        :raises ValueError: if 'package' or 'class' is missing
        """
        if 'package' not in module_cfg:
            raise ValueError(f'No package path provided for {description}!')
        if 'class' not in module_cfg:
            raise ValueError(f'No class name provided for {description}!')

        # Each module gets its own copy of the global arguments (which
        # include the whole configuration under 'settings'), optionally
        # extended / overridden by its own 'arguments'.
        args = deepcopy(self.global_arguments)
        if 'arguments' in module_cfg:
            args.update(module_cfg['arguments'])

        return self.load_module(module_cfg['package'], module_cfg['class'],
                                args)

    def __del__(self):
        """
        Do some house-keeping: save the experience log and the modules.

        __del__ also runs on instances whose __init__ raised part-way, in
        which case some attributes may never have been assigned; they are
        therefore read with getattr and a neutral default instead of raising
        a spurious AttributeError during finalization.

        :return: nothing
        """
        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', False):
            recorder.save()

        for m in getattr(self, 'ConversationalModules', []):
            if isinstance(m, list):
                # Group of parallel sub-modules
                for sm in m:
                    sm.save()
            else:
                m.save()

    # Dynamically load classes
    @staticmethod
    def load_module(package_path, class_name, args):
        """
        Dynamically import a module and instantiate one of its classes.

        :param package_path: dotted path of the module to import
        :param class_name: name of the class within that module
        :param args: the single argument passed to the class constructor
        :return: the instantiated class object
        :raises ImportError: if the module cannot be imported
        :raises AttributeError: if the module has no attribute class_name
        """
        # importlib.import_module is the documented replacement for direct
        # __import__ calls and returns the leaf module for dotted paths.
        import importlib

        module = importlib.import_module(package_path)
        try:
            klass = getattr(module, class_name)
        except AttributeError as err:
            raise AttributeError(
                f'Module {package_path} has no class {class_name}') from err
        return klass(args)

    def initialize(self):
        """
        Reset the agent's counters and goal, then initialize every module.

        Parallel module groups (lists) are expanded so that each sub-module
        is initialized on its own, in configuration order, with an empty
        argument dict.

        :return: Nothing
        """
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.cumulative_rewards = 0
        self.agent_goal = None

        every_module = (
            sub_module
            for entry in self.ConversationalModules
            for sub_module in (entry if isinstance(entry, list) else [entry])
        )
        for module in every_module:
            module.initialize({})

    def start_dialogue(self, args=None):
        """
        Reset or initialize internal structures at the beginning of the
        dialogue and perform the first turn.

        The modules are (re-)initialized through initialize(), but the
        episode counter is preserved. end_dialogue() uses it to decide when
        to train (train_interval) and save (SAVE_INTERVAL) the modules.
        Letting initialize() reset it here would make it 0 at the end of
        every dialogue, i.e. train and save after every single dialogue.

        :param args: optional dict; 'goal' sets this agent's goal (otherwise
                     the goal generator, if any, is used). args is also
                     passed on to continue_dialogue().
        :return: a dict with the output of the first turn and the goal
        """
        # initialize() resets dialogue_episode to 0; restore it afterwards.
        dialogue_episode = self.dialogue_episode
        self.initialize()
        self.dialogue_episode = dialogue_episode
        self.dialogue_turn = 0

        if args and 'goal' in args:
            self.agent_goal = deepcopy(args['goal'])

        elif self.goal_generator:
            self.agent_goal = self.goal_generator.generate()
            print(f'GOAL:\n=====\n{self.agent_goal}')

        # TODO: Get initial trigger from config
        # The trigger is what the module chain processes unless
        # continue_dialogue() obtains other input (console, ASR, args).
        if self.INTERACTION_MODE == 'dialogue_acts':
            self.prev_m_out = \
                ConversationalFrame([DialogueAct('hello')])
        else:
            self.prev_m_out = ConversationalFrame('hello')

        self.continue_dialogue(args)

        return {
            'input_utterance': None,
            'output_raw': self.prev_m_out.content,
            'output_dacts': '',
            'goal': self.agent_goal
        }

    def continue_dialogue(self, args=None):
        """
        Perform one dialogue turn.

        The turn's input comes from the console (text mode), the microphone
        (speech mode) - both only when no GUI is used - or from
        args['input']. It is then pushed through the module chain: each
        module's output becomes the next module's input. A list entry in the
        chain is a group of parallel sub-modules that all receive a copy of
        the same input; their outputs are collected into a dict whose keys
        are 'sm0', 'sm1', ...

        :param args: input to this agent (a dict, optionally with 'input')
        :return: dict with the raw input utterance (text / speech modes), the
                 raw output of the last module and this agent's goal
        """
        utterance = None
        reads_console = not self.USE_GUI

        if reads_console and self.INTERACTION_MODE == 'text':
            utterance = input('USER > ')
            self.prev_m_out = ConversationalFrame(utterance)

        elif reads_console and self.INTERACTION_MODE == 'speech':
            # Listen for input from the microphone
            with speech_rec.Microphone() as mic:
                print('(listening...)')
                recording = self.asr.listen(mic, phrase_time_limit=3)

            # NOTE(review): when recognition fails, prev_m_out keeps its
            # previous value and is fed through the chain again - confirm
            # this is intended.
            try:
                # This uses the default key
                utterance = self.asr.recognize_google(recording)
                print("Google ASR: " + utterance)
                self.prev_m_out = ConversationalFrame(utterance)
            except speech_rec.UnknownValueError:
                print("Google ASR did not understand you")
            except speech_rec.RequestError as e:
                print("Google ASR request error: {0}".format(e))

        elif args and 'input' in args:
            self.prev_m_out = ConversationalFrame(args['input'])

        for module in self.ConversationalModules:
            if not isinstance(module, list):
                # WARNING! Module compatibility cannot be guaranteed here!
                module.generic_receive_input(self.prev_m_out)
                produced = module.generic_generate_output(self.prev_m_out)

                # Make sure prev_m_out is a Conversational Frame
                self.prev_m_out = (
                    produced if isinstance(produced, ConversationalFrame)
                    else ConversationalFrame(produced))
            else:
                # Parallel sub-modules all consume a copy of the same input;
                # the current frame is reused to collect their outputs.
                shared_input = deepcopy(self.prev_m_out)
                self.prev_m_out.content = {}

                for position, sub_module in enumerate(module):
                    # WARNING! Module compatibility cannot be guaranteed here!
                    sub_module.generic_receive_input(shared_input)
                    sub_out = sub_module.generic_generate_output(shared_input)

                    if not isinstance(sub_out, ConversationalFrame):
                        sub_out = ConversationalFrame(sub_out)

                    self.prev_m_out.content['sm' + str(position)] = \
                        sub_out.content

            # DEBUG:
            if isinstance(self.prev_m_out.content, str):
                print(f'(DEBUG) {self.agent_role}> '
                      f'{str(self.prev_m_out.content)}')

        self.dialogue_turn += 1

        # In text or speech based interactions, return the input utterance as
        # it may be used for statistics or to show it to a GUI.
        return {
            'input_utterance': utterance,
            'output_raw': self.prev_m_out.content,
            'output_dacts': '',
            'goal': self.agent_goal
        }

    def end_dialogue(self):
        """
        Wrap up the current dialogue.

        Every train_interval episodes all modules (including the sub-modules
        of parallel groups) are trained on the recorder's dialogues. Every
        SAVE_INTERVAL episodes they are saved. Afterwards the episode and
        turn statistics are updated, and the dialogue is counted as
        successful when the reward function reports objective success.

        :return: nothing
        """
        episode = self.dialogue_episode
        all_modules = [
            sub_module
            for entry in self.ConversationalModules
            for sub_module in (entry if isinstance(entry, list) else [entry])
        ]

        if episode % self.train_interval == 0:
            for module in all_modules:
                module.train(self.recorder.dialogues)

        if episode % self.SAVE_INTERVAL == 0:
            for module in all_modules:
                module.save()

        # Keep track of dialogue statistics
        self.dialogue_episode += 1

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        # Count successful dialogues
        _, _, objective_success = self.reward_func.calculate(
            self.get_state(),
            [],
            # TODO: In case of single agents, we actually need the user's goal
            goal=self.agent_goal,
            agent_role=self.agent_role)

        if objective_success:
            self.num_successful_dialogues += 1

    def terminated(self):
        """
        Tell whether this agent's dialogue has reached its end.

        The dialogue ends when the stateful module (index STATEFUL_MODULE in
        the module chain; the last module by default) reports a terminal
        state, or once dialogue_turn exceeds MAX_TURNS.

        :return: True or False
        """
        state_holder = self.ConversationalModules[self.STATEFUL_MODULE]
        at_terminal = state_holder.at_terminal_state()
        if at_terminal:
            return at_terminal

        return self.dialogue_turn > self.MAX_TURNS

    def set_goal(self, goal):
        """
        Replace this agent's current goal.

        :param goal: a Goal (stored as given, not copied)
        :return: nothing
        """
        self.agent_goal = goal

    def get_goal(self):
        """
        Return this agent's current goal.

        :return: a Goal, or None if no goal has been set or generated yet
        """
        return self.agent_goal

    def get_state(self):
        """
        Return the dialogue state held by the stateful module.

        :return: the result of get_state() on the module at index
                 STATEFUL_MODULE of the module chain
        """
        state_holder = self.ConversationalModules[self.STATEFUL_MODULE]
        return state_holder.get_state()