class ConversationalMultiAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - Dialogue Acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd?

    - Data
    """
    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        Builds the experience recorder and reward function, then reads the
        GENERAL, DIALOGUE and AGENT_<agent_id> sections of the configuration
        to set up the ontology, database, NLU, NLG, goal generator (user role
        only) and finally the dialogue manager.

        NOTE(review): another ``__init__`` definition appears later in this
        class body; Python keeps only the last one, so confirm which is
        intended and remove the other.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        :raises ValueError: if the configuration is empty; lacks the GENERAL,
            DIALOGUE or AGENT_<agent_id> section or an interaction mode;
            names a CamRest NLU/NLG without a model_path; assigns no role to
            this agent; or a user agent has no ontology/database
        :raises FileNotFoundError: if a configured ontology, database, goals
            or dialogue log file does not exist
        """

        super(ConversationalMultiAgent, self).__init__()

        self.agent_id = agent_id

        # Flag to alternate training between roles
        self.train_system = True

        # Dialogue statistics
        self.dialogue_episode = 1
        # Note that since this is a multi-agent setting, dialogue_turn refers
        # to this agent's turns only
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 200
        self.train_interval = 50
        self.train_epochs = 3

        # Alternate training between the agents
        self.train_alternate_training = True
        self.train_switch_trainable_agents_every = self.train_interval

        self.configuration = configuration

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLU = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 10

        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None
        self.goals_path = None

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        # The size of the experience pool is a hyperparameter.
        # Do not have an experience window larger than the current batch,
        # as past experience may not be relevant since both agents learn.
        self.recorder = DialogueEpisodeRecorder(size=self.minibatch_length)

        # TODO: Get reward type from the config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        # The dialogue manager below cannot be built without a configuration,
        # so fail early with a clear message instead of a TypeError.
        if not self.configuration:
            raise ValueError('Cannot run Plato without a configuration!')

        agent_id_str = 'AGENT_' + str(self.agent_id)

        general_cfg = self.configuration.get('GENERAL')
        dialogue_cfg = self.configuration.get('DIALOGUE')
        agent_cfg = self.configuration.get(agent_id_str)

        # Error checks for options the config must have. Use .get() so that a
        # missing section raises the intended ValueError, not a KeyError.
        if not general_cfg:
            raise ValueError('Cannot run Plato without GENERAL settings!')
        if not general_cfg.get('interaction_mode'):
            raise ValueError('Cannot run Plato without an interaction mode!')
        if not dialogue_cfg:
            raise ValueError('Cannot run Plato without DIALOGUE settings!')
        if not agent_cfg:
            raise ValueError('Cannot run Plato without at least one agent!')

        # Dialogue domain settings
        if 'initiative' in dialogue_cfg:
            self.USER_HAS_INITIATIVE = dialogue_cfg['initiative'] == 'user'

        if dialogue_cfg.get('domain'):
            self.domain = dialogue_cfg['domain']

        ontology_path = dialogue_cfg.get('ontology_path')
        if ontology_path:
            if not os.path.isfile(ontology_path):
                raise FileNotFoundError(
                    'Domain file %s not found' % ontology_path)
            self.ontology = Ontology.Ontology(ontology_path)

        db_path = dialogue_cfg.get('db_path')
        if db_path:
            if not os.path.isfile(db_path):
                raise FileNotFoundError(
                    'Database file %s not found' % db_path)
            # Default to SQL when no db_type is configured
            if dialogue_cfg.get('db_type', 'sql') == 'sql':
                self.database = DataBase.SQLDataBase(db_path)
            else:
                self.database = DataBase.DataBase(db_path)

        # Treat an empty/null goals_path like an absent one, consistent with
        # ontology_path and db_path above.
        goals_path = dialogue_cfg.get('goals_path')
        if goals_path:
            if not os.path.isfile(goals_path):
                raise FileNotFoundError(
                    'Goals file %s not found' % goals_path)
            self.goals_path = goals_path

        # General settings
        if 'experience_logs' in general_cfg:
            logs_cfg = general_cfg['experience_logs'] or {}
            dialogues_path = logs_cfg.get('path')

            if logs_cfg.get('load'):
                if dialogues_path and os.path.isfile(dialogues_path):
                    self.recorder.load(dialogues_path)
                else:
                    raise FileNotFoundError(
                        'Dialogue Log file %s not found (did you '
                        'provide one?)' % dialogues_path)

            if 'save' in logs_cfg:
                self.recorder.set_path(dialogues_path)
                self.SAVE_LOG = bool(logs_cfg['save'])

        if general_cfg['interaction_mode'] == 'simulation':
            self.USE_USR_SIMULATOR = True

        # NLU Settings
        nlu_cfg = agent_cfg.get('NLU')
        if nlu_cfg and nlu_cfg.get('nlu'):
            nlu_args = {'ontology': self.ontology, 'database': self.database}

            if nlu_cfg['nlu'] == 'dummy':
                self.nlu = DummyNLU(nlu_args)
                self.USE_NLU = True

            elif nlu_cfg['nlu'] == 'CamRest':
                if not nlu_cfg.get('model_path'):
                    raise ValueError('Cannot find model_path in the config.')
                nlu_args['model_path'] = nlu_cfg['model_path']
                self.nlu = CamRestNLU(nlu_args)
                self.USE_NLU = True

        # NLG Settings
        nlg_cfg = agent_cfg.get('NLG')
        if nlg_cfg and nlg_cfg.get('nlg'):
            if nlg_cfg['nlg'] == 'dummy':
                self.nlg = DummyNLG()

            elif nlg_cfg['nlg'] == 'CamRest':
                if not nlg_cfg.get('model_path'):
                    raise ValueError('Cannot find model_path in the config.')
                self.nlg = CamRestNLG({'model_path': nlg_cfg['model_path']})

            if self.nlg:
                self.USE_NLG = True

        # Retrieve agent role
        if 'role' not in agent_cfg:
            raise ValueError('ConversationalMultiAgent: No role assigned '
                             'for agent {0} in '
                             'config!'.format(self.agent_id))
        self.agent_role = agent_cfg['role']

        # User agents sample their own goals, which needs both resources.
        if self.agent_role == 'user':
            if not (self.ontology and self.database):
                raise ValueError('Conversational Multi Agent (user): '
                                 'Cannot generate goal without ontology '
                                 'and database.')
            self.goal_generator = GoalGenerator(
                ontology=self.ontology,
                database=self.database,
                goals_file=self.goals_path)

        dm_args = {
            'settings': self.configuration,
            'ontology': self.ontology,
            'database': self.database,
            'domain': self.domain,
            'agent_id': self.agent_id,
            'agent_role': self.agent_role,
        }
        # Agent-specific DM settings (if any) override the shared arguments.
        dm_args.update(agent_cfg.get('DM') or {})
        self.dialogue_manager = DialogueManager.DialogueManager(dm_args)

    def __del__(self):
        """
        Do some house-keeping and save the models.

        Saves the experience log (when SAVE_LOG is set) and the dialogue
        manager, then clears the per-turn bookkeeping. Attributes are read
        with getattr because __del__ also runs on instances whose __init__
        raised before assigning them.

        :return: nothing
        """

        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', False):
            recorder.save()

        dialogue_manager = getattr(self, 'dialogue_manager', None)
        if dialogue_manager:
            dialogue_manager.save()

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Initializes the conversational agent based on settings in the
        configuration file.

        Resets all episode statistics and per-turn bookkeeping, and
        initializes the NLU (if any), the dialogue manager and the NLG
        (if any).

        :return: Nothing
        """

        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        # Reset together with the other statistics so they all restart from
        # the same point.
        self.total_dialogue_turns = 0

        if self.nlu:
            self.nlu.initialize({})

        self.dialogue_manager.initialize({})

        if self.nlg:
            self.nlg.initialize({})

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def start_dialogue(self, goal=None):
        """
        Perform initial dialogue turn.

        Goal selection depends on the role:

        - A 'user' agent always samples a fresh goal from its goal generator;
          the ``goal`` argument is ignored.
        - Any other role must be given ``goal``.

        The dialogue manager is then restarted with that goal. The returned
        acts are always a single welcomemsg act; only a 'system' agent renders
        and prints it.

        :param goal: an optional Goal (required unless the role is 'user')
        :return: a tuple (utterance, dialogue acts, goal); the utterance is ''
                 unless this is a 'system' agent that uses NLG
        :raises ValueError: if a non-user agent is not given a goal
        """

        self.dialogue_turn = 0

        if self.agent_role == 'user':
            self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.update_goal(self.agent_goal)

            print('DEBUG > Usr goal:')
            for c in self.agent_goal.constraints:
                print(f'\t\tConstr({self.agent_goal.constraints[c].slot}='
                      f'{self.agent_goal.constraints[c].value})')
            print('\t\t-------------')
            for r in self.agent_goal.requests:
                print(f'\t\tReq({self.agent_goal.requests[r].slot})')
            print('\n')

        elif goal:
            # No deep copy here so that all agents see the same goal.
            self.agent_goal = goal
        else:
            raise ValueError('ConversationalMultiAgent - no goal provided '
                             'for agent {0}!'.format(self.agent_role))

        self.dialogue_manager.restart({'goal': self.agent_goal})

        # The DM's initial output is forced to be the welcomemsg act. To let
        # the DM's model choose instead, push an empty list of DialogueActs:
        #   self.dialogue_manager.receive_input([])
        #   response = self.dialogue_manager.respond()
        response = [DialogueAct('welcomemsg', [])]
        response_utterance = ''

        if self.agent_role == 'system':
            if self.USE_NLG:
                response_utterance = self.nlg.generate_output({
                    'dacts': response,
                    'system': True,
                    'last_sys_utterance': ''
                })

                print('{0} > {1}'.format(self.agent_role.upper(),
                                         response_utterance))

                if self.USE_SPEECH:
                    tts = gTTS(text=response_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')
            else:
                print('{0} > {1}'.format(
                    self.agent_role.upper(),
                    '; '.join([str(sr) for sr in response])))

        # TODO: Generate output depending on initiative - i.e.
        # have users also start the dialogue

        # Re-initialize the per-turn bookkeeping for the new dialogue
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        return response_utterance, response, self.agent_goal

    def continue_dialogue(self, args):
        """
        Perform next dialogue turn.

        The turn proceeds in four steps:

        1. Feed the other agent's input to the dialogue manager.
        2. Produce this agent's response; once MAX_TURNS is reached the
           response is a forced 'bye'.
        3. Score the response with the reward function.
        4. Record the previous turn's transition in the experience recorder.

        :param args: a dictionary with the keys
            - 'other_input_raw' (required): the other agent's raw output;
              passed through the NLU when it is a string and an NLU is set
            - 'other_input_dact' (optional): the other agent's dialogue acts,
              used (if non-empty) when the raw input is not run through NLU
            - 'goal' (optional): if truthy, replaces this agent's goal
        :return: a tuple (utterance, dialogue acts, goal) with this agent's
                 output; the utterance is '' when NLG is not used
        :raises ValueError: if 'other_input_raw' is missing from args
        """

        if 'other_input_raw' not in args:
            raise ValueError(
                'ConversationalMultiAgent called without raw input!')

        other_input_raw = args['other_input_raw']
        other_input_dact = args.get('other_input_dact')
        goal = args.get('goal')

        if goal:
            self.agent_goal = goal

        sys_utterance = ''

        if self.nlu and isinstance(other_input_raw, str):
            # Process the other agent's utterance
            other_input_nlu = self.nlu.process_input(
                other_input_raw, self.dialogue_manager.get_state())
        elif other_input_dact:
            # If no utterance provided, use the dacts
            other_input_nlu = other_input_dact
        else:
            # Otherwise pass a copy of the raw input so the caller's object
            # is not shared with the dialogue manager.
            other_input_nlu = deepcopy(other_input_raw)

        self.dialogue_manager.receive_input(other_input_nlu)

        # Snapshot the state the dialogue manager will use to make a
        # decision; it becomes prev_state for the DialogueEpisodeRecorder on
        # the next turn.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        # Update goal's ground truth
        if self.agent_role == 'system':
            self.agent_goal.ground_truth = deepcopy(
                self.curr_state.item_in_focus)

        if self.dialogue_turn < self.MAX_TURNS:
            response = self.dialogue_manager.generate_output()
            # Keep this agent's goal in sync with the goal held by the DM's
            # dialogue state tracker.
            self.agent_goal = self.dialogue_manager.DSTracker.DState.user_goal
        else:
            # Force dialogue stop due to too many turns
            response = [DialogueAct('bye', [])]

        rew, success, task_success = self.reward_func.calculate(
            self.dialogue_manager.get_state(),
            response,
            goal=self.agent_goal,
            agent_role=self.agent_role)

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output(
                {
                    'dacts': response,
                    'system': self.agent_role == 'system',
                    'last_sys_utterance': other_input_raw
                }) + ' '

            print('{0} > {1}'.format(self.agent_role.upper(), sys_utterance))

            if self.USE_SPEECH:
                tts = gTTS(text=sys_utterance, lang='en')
                tts.save('sys_output.mp3')
                os.system('afplay sys_output.mp3')
        else:
            print('{0} > {1} \n'.format(
                self.agent_role.upper(),
                '; '.join([str(sr) for sr in response])))

        # Record the previous turn's transition (prev_state -> curr_state).
        # NOTE(review): the utterances recorded here belong to the current
        # turn, whereas end_dialogue records prev_usr_utterance /
        # prev_sys_utterance; confirm which pairing is intended.
        if self.prev_state is not None:
            self.recorder.record(self.prev_state,
                                 self.curr_state,
                                 self.prev_action,
                                 self.prev_reward,
                                 self.prev_success,
                                 input_utterance=other_input_raw,
                                 output_utterance=sys_utterance,
                                 task_success=self.prev_task_success)

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_usr_utterance = deepcopy(other_input_raw)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_action = deepcopy(response)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

        return sys_utterance, response, self.agent_goal

    def end_dialogue(self):
        """
        Perform final dialogue turn. Train and save models if applicable.

        The following steps run in order:

        - Flip which role is trainable every
          ``train_switch_trainable_agents_every`` episodes.
        - Record the final transition.
        - Train the dialogue manager every ``train_interval`` episodes, once
          enough dialogues are recorded.
        - Update the episode statistics, and save the dialogue manager every
          10000 episodes.

        Assumes continue_dialogue ran at least once this episode; otherwise
        curr_state is still None.

        :return: nothing
        """

        if self.dialogue_episode % \
                self.train_switch_trainable_agents_every == 0:
            self.train_system = not self.train_system

        # Record final state. If the DM did not reach a terminal state, force
        # one and score the dialogue as if this agent had said bye.
        if not self.curr_state.is_terminal_state:
            self.curr_state.is_terminal_state = True
            self.prev_reward, self.prev_success, self.prev_task_success = \
                self.reward_func.calculate(
                    self.curr_state,
                    [DialogueAct('bye', [])],
                    goal=self.agent_goal,
                    agent_role=self.agent_role
                )

        self.recorder.record(self.curr_state,
                             self.curr_state,
                             self.prev_action,
                             self.prev_reward,
                             self.prev_success,
                             input_utterance=self.prev_usr_utterance,
                             output_utterance=self.prev_sys_utterance,
                             task_success=self.prev_task_success,
                             force_terminate=True)

        if self.dialogue_manager.is_training():
            # With alternate training, only the role selected by train_system
            # learns during this stretch of episodes.
            role_is_trainable = (
                (self.train_system and self.agent_role == 'system') or
                (not self.train_system and self.agent_role == 'user'))

            if (not self.train_alternate_training or role_is_trainable) and \
                    (self.dialogue_episode + 1) % self.train_interval == 0 and \
                    len(self.recorder.dialogues) >= self.minibatch_length:
                for epoch in range(self.train_epochs):
                    print('{0}: Training epoch {1} of {2}'.format(
                        self.agent_role, (epoch + 1), self.train_epochs))

                    # Sample minibatch
                    minibatch = random.sample(self.recorder.dialogues,
                                              self.minibatch_length)
                    self.dialogue_manager.train(minibatch)

        self.dialogue_episode += 1

        last_dialogue = self.recorder.dialogues[-1]
        last_turn = last_dialogue[-1]

        self.cumulative_rewards += last_turn['cumulative_reward']
        # dialogue_turn is never negative, so it can be added unconditionally.
        self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % 10000 == 0:
            self.dialogue_manager.save()

        dialogue_reward = sum(t['reward'] for t in last_dialogue)

        # Count successful dialogues
        if last_turn['success']:
            print('{0} SUCCESS! (reward: {1})'.format(
                self.agent_role, dialogue_reward))
            self.num_successful_dialogues += int(last_turn['success'])
        else:
            print('{0} FAILURE. (reward: {1})'.format(
                self.agent_role, dialogue_reward))

        if last_turn['task_success']:
            self.num_task_success += int(last_turn['task_success'])

    def terminated(self):
        """
        Tell whether the dialogue is over from this agent's point of view.

        A 'bye' act among the dialogue state's user_acts ends the dialogue
        unconditionally. Otherwise the dialogue manager decides whether it is
        at a terminal state.

        :return: True or False
        """

        # 'bye' is hard-wired to terminate: if any agent says bye, the
        # dialogue ends. Otherwise, in multi-agent settings, it is very hard
        # to learn the association and learn to terminate the dialogue.
        incoming_acts = self.dialogue_manager.get_state().user_acts or []
        if any(act.intent == 'bye' for act in incoming_acts):
            return True

        return self.dialogue_manager.at_terminal_state()

    def get_goal(self):
        """
        Return the goal this agent is currently pursuing.

        :return: a Goal (None until one has been generated or assigned)
        """
        current_goal = self.agent_goal
        return current_goal

    def set_goal(self, goal):
        """
        Replace this agent's goal with the given one.

        The goal is stored by reference rather than copied. Changes this agent
        later makes to it are therefore visible to every other holder of the
        same object (e.g. the reward function), keeping reward calculations
        up to date.

        :param goal: a Goal
        :return: nothing
        """
        # TODO: decide whether a deep copy would be preferable here.
        self.agent_goal = goal

    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        """

        super(ConversationalMultiAgent, self).__init__()

        self.agent_id = agent_id

        # Flag to alternate training between roles
        self.train_system = True

        # Dialogue statistics
        self.dialogue_episode = 1
        # Note that since this is a multi-agent setting, dialogue_turn refers
        # to this agent's turns only
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 200
        self.train_interval = 50
        self.train_epochs = 3

        # Alternate training between the agents
        self.train_alternate_training = True
        self.train_switch_trainable_agents_every = self.train_interval

        self.configuration = configuration

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLU = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act.
        self.MAX_TURNS = 10

        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None
        self.goals_path = None

        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        # The size of the experience pool is a hyperparameter.

        # Do not have an experience window larger than the current batch,
        # as past experience may not be relevant since both agents learn.
        self.recorder = DialogueEpisodeRecorder(size=self.minibatch_length)

        # TODO: Get reward type from the config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            agent_id_str = 'AGENT_' + str(self.agent_id)

            # Error checks for options the config must have
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an interaction '
                                 'mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration[agent_id_str]:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # Dialogue domain self.settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user')

                if self.configuration['DIALOGUE']['domain']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if self.configuration['DIALOGUE']['ontology_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']):
                        self.ontology = \
                            Ontology.Ontology(
                                self.configuration['DIALOGUE']['ontology_path']
                            )
                    else:
                        raise FileNotFoundError(
                            'Domain file %s not '
                            'found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                if self.configuration['DIALOGUE']['db_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']):
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = DataBase.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                            else:
                                self.database = DataBase.DataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                        else:
                            # Default to SQL
                            self.database = DataBase.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path'])
                    else:
                        raise FileNotFoundError(
                            'Database file %s not '
                            'found' %
                            self.configuration['DIALOGUE']['db_path'])

                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not '
                            'found' %
                            self.configuration['DIALOGUE']['goals_path'])

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    if 'load' in \
                            self.configuration['GENERAL'][
                                'experience_logs'] and \
                            bool(
                                self.configuration[
                                    'GENERAL'
                                ]['experience_logs']['load']
                            ):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'Dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(self.configuration['GENERAL']
                                             ['experience_logs']['save'])

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

            # NLU Settings
            if 'NLU' in self.configuration[agent_id_str] and \
                    self.configuration[agent_id_str]['NLU'] and \
                    self.configuration[agent_id_str]['NLU']['nlu']:
                nlu_args = dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))

                if self.configuration[agent_id_str]['NLU']['nlu'] == 'dummy':
                    self.nlu = DummyNLU(nlu_args)
                    self.USE_NLU = True

                elif self.configuration[agent_id_str]['NLU']['nlu'] == \
                        'CamRest':
                    if self.configuration[agent_id_str]['NLU']['model_path']:
                        nlu_args['model_path'] = \
                            self.configuration[
                                agent_id_str
                            ]['NLU']['model_path']
                        self.nlu = CamRestNLU(nlu_args)
                        self.USE_NLU = True
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

            # NLG Settings
            if 'NLG' in self.configuration[agent_id_str] and \
                    self.configuration[agent_id_str]['NLG'] and \
                    self.configuration[agent_id_str]['NLG']['nlg']:
                if self.configuration[agent_id_str]['NLG']['nlg'] == 'dummy':
                    self.nlg = DummyNLG()

                elif self.configuration[agent_id_str]['NLG']['nlg'] == \
                        'CamRest':
                    if self.configuration[agent_id_str]['NLG']['model_path']:
                        self.nlg = CamRestNLG({
                            'model_path':
                            self.configuration[agent_id_str]['NLG']
                            ['model_path']
                        })
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

                if self.nlg:
                    self.USE_NLG = True

            # Retrieve agent role
            if 'role' in self.configuration[agent_id_str]:
                self.agent_role = self.configuration[agent_id_str]['role']
            else:
                raise ValueError('ConversationalMultiAgent: No role assigned '
                                 'for agent {0} in '
                                 'config!'.format(self.agent_id))

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator(
                        ontology=self.ontology,
                        database=self.database,
                        goals_file=self.goals_path)
                else:
                    raise ValueError('Conversational Multi Agent (user): '
                                     'Cannot generate goal without ontology '
                                     'and database.')

        dm_args = dict(
            zip([
                'settings', 'ontology', 'database', 'domain', 'agent_id',
                'agent_role'
            ], [
                self.configuration, self.ontology, self.database, self.domain,
                self.agent_id, self.agent_role
            ]))

        dm_args.update(self.configuration['AGENT_' + str(agent_id)]['DM'])
        self.dialogue_manager = DialogueManager.DialogueManager(dm_args)
    def __init__(self, configuration):
        """
        Initialize the internal structures of a single-agent dialogue system.

        NOTE(review): this definition sits at method level right after the
        multi-agent ``__init__(self, configuration, agent_id)`` and appears to
        be a pasted copy of ``ConversationalSingleAgent.__init__``. If it
        really lives in the ``ConversationalMultiAgent`` class body, it
        replaces the earlier ``__init__`` (the last ``def`` of a name in a
        class body wins), so ``ConversationalMultiAgent(config, agent_id)``
        would fail with TypeError. Also, ``super(ConversationalSingleAgent,
        self)`` below raises TypeError unless ``self`` is a
        ``ConversationalSingleAgent``. Confirm and remove the duplicate.

        The agent id is always 0 and only the ``AGENT_0`` section of the
        configuration is read.

        :param configuration: a dictionary representing the configuration
            file. When truthy it must contain non-empty 'GENERAL' (with
            'interaction_mode'), 'DIALOGUE' and 'AGENT_0' sections, and
            'AGENT_0' must provide 'role' and 'DM'. Missing keys generally
            surface as KeyError rather than ValueError.
        :raises ValueError: if a required section, the interaction mode, the
            agent role, or a CamRest NLU/NLG model_path is empty/missing, if
            a 'dtl' user simulator has no policy file, or if a 'user' agent
            lacks an ontology or database.
        :raises FileNotFoundError: if a configured ontology, database, goals
            or experience-log file does not exist.
        """

        # NOTE(review): see docstring -- this super() target only works when
        # self is a ConversationalSingleAgent.
        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Training settings (not read in this method)
        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 10

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act.
        self.MAX_TURNS = 15

        # Overrides the dialogue_turn = 0 assignment above.
        self.dialogue_turn = -1
        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        # Previous-turn bookkeeping (curr_state is assigned twice; harmless).
        self.curr_state = None
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            # Error checks for options the config must have.
            # An absent key raises KeyError here; only present-but-falsy
            # values produce the ValueErrors below.
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an '
                                 'interaction mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration['AGENT_0']:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # Dialogue domain self.settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    # Any initiative value other than 'user' gives the
                    # initiative to the system.
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user')
                    self.user_simulator_args['us_has_initiative'] = \
                        self.USER_HAS_INITIATIVE

                if self.configuration['DIALOGUE']['domain']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if self.configuration['DIALOGUE']['ontology_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']):
                        self.ontology = Ontology.Ontology(
                            self.configuration['DIALOGUE']['ontology_path'])
                    else:
                        raise FileNotFoundError(
                            'Domain file %s not found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                if self.configuration['DIALOGUE']['db_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']):
                        # db_type 'sql' or no db_type -> SQLDataBase;
                        # any other db_type -> DataBase.
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = DataBase.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                            else:
                                self.database = DataBase.DataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                        else:
                            # Default to SQL
                            self.database = DataBase.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path'])
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['DIALOGUE']['db_path'])

                # NOTE(review): within this method self.goals_path is only
                # assigned here, so it stays undefined when 'goals_path' is
                # absent; it is also not passed to GoalGenerator below.
                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not found' %
                            self.configuration['DIALOGUE']['goals_path'])

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    # Loading previous experience requires an existing file.
                    if 'load' in \
                            self.configuration['GENERAL']['experience_logs'] \
                        and bool(
                            self.configuration['GENERAL'][
                                'experience_logs']['load']
                    ):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'Dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    # NOTE(review): set_path is called whenever 'save' is
                    # present -- even if 'save' is false or no 'path' was
                    # given (dialogues_path is then None).
                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(self.configuration['GENERAL']
                                             ['experience_logs']['save'])

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

                elif self.configuration['GENERAL']['interaction_mode'] == \
                        'speech':
                    # Speech mode: create a speech recognizer (self.asr is
                    # only set on this path).
                    self.USE_SPEECH = True
                    self.asr = speech_rec.Recognizer()

            # Agent Settings

            # Usr Simulator
            # Check for specific simulator self.settings, otherwise
            # default to agenda
            # NOTE(review): if AGENT_0 has no 'USER_SIMULATOR' key at all,
            # no simulator is created here and self.user_simulator stays None.
            if 'USER_SIMULATOR' in self.configuration['AGENT_0']:
                # Agent 0 simulator configuration
                a0_sim_config = self.configuration['AGENT_0']['USER_SIMULATOR']
                if a0_sim_config and a0_sim_config['simulator']:
                    # Default settings
                    self.user_simulator_args['ontology'] = self.ontology
                    self.user_simulator_args['database'] = self.database
                    self.user_simulator_args['um'] = self.user_model
                    self.user_simulator_args['patience'] = 5

                    # NOTE(review): a 'simulator' value other than 'agenda'
                    # or 'dtl' leaves self.user_simulator as None.
                    if a0_sim_config['simulator'] == 'agenda':
                        if 'patience' in a0_sim_config:
                            self.user_simulator_args['patience'] = \
                                int(a0_sim_config['patience'])

                        # A non-list value is evaluated as a Python
                        # expression.
                        # NOTE(review): eval() executes arbitrary code from
                        # the config; ast.literal_eval would be safer if the
                        # config file is not fully trusted.
                        if 'pop_distribution' in a0_sim_config:
                            if isinstance(a0_sim_config['pop_distribution'],
                                          list):
                                self.user_simulator_args['pop_distribution'] =\
                                    a0_sim_config['pop_distribution']
                            else:
                                self.user_simulator_args['pop_distribution'] =\
                                    eval(a0_sim_config['pop_distribution'])

                        if 'slot_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['slot_confuse_prob'] = \
                                float(a0_sim_config['slot_confuse_prob'])
                        if 'op_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['op_confuse_prob'] = \
                                float(a0_sim_config['op_confuse_prob'])
                        if 'value_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['value_confuse_prob'] = \
                                float(a0_sim_config['value_confuse_prob'])

                        if 'goal_slot_selection_weights' in a0_sim_config:
                            self.user_simulator_args[
                                'goal_slot_selection_weights'] = a0_sim_config[
                                    'goal_slot_selection_weights']

                        if 'nlu' in a0_sim_config:
                            self.user_simulator_args['nlu'] = \
                                a0_sim_config['nlu']

                            # database was already set in the defaults
                            # above; this re-assignment is redundant.
                            if self.user_simulator_args['nlu'] == 'dummy':
                                self.user_simulator_args['database'] = \
                                    self.database

                            self.USER_SIMULATOR_NLU = True

                        if 'nlg' in a0_sim_config:
                            self.user_simulator_args['nlg'] = \
                                a0_sim_config['nlg']

                            # NOTE(review): a0_sim_config is always truthy
                            # here (checked by the enclosing `if`), so the
                            # ValueError branch is unreachable; a missing
                            # 'nlg_model_path' raises KeyError instead.
                            if self.user_simulator_args['nlg'] == 'CamRest':
                                if a0_sim_config:
                                    self.user_simulator_args[
                                        'nlg_model_path'] = a0_sim_config[
                                            'nlg_model_path']

                                    self.USER_SIMULATOR_NLG = True

                                else:
                                    raise ValueError(
                                        'Usr Simulator NLG: Cannot find '
                                        'model_path in the config.')

                            elif self.user_simulator_args['nlg'] == 'dummy':
                                self.USER_SIMULATOR_NLG = True

                        if 'goals_file' in a0_sim_config:
                            self.user_simulator_args['goals_file'] = \
                                a0_sim_config['goals_file']

                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']

                        self.user_simulator = AgendaBasedUS(
                            self.user_simulator_args)

                    elif a0_sim_config['simulator'] == 'dtl':
                        # DAct-to-Language simulator; a policy file is
                        # mandatory.
                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']
                            self.user_simulator = DTLUserSimulator(
                                self.user_simulator_args)
                        else:
                            raise ValueError(
                                'Error! Cannot start DAct-to-Language '
                                'simulator without a policy file!')

                else:
                    # Fallback to agenda based simulator with default settings
                    # NOTE(review): the ontology/database/um/patience defaults
                    # above are not applied on this path, so the simulator
                    # receives only what user_simulator_args already holds
                    # (at most 'us_has_initiative').
                    self.user_simulator = AgendaBasedUS(
                        self.user_simulator_args)

            # NLU Settings
            if 'NLU' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLU'] and \
                    self.configuration['AGENT_0']['NLU']['nlu']:
                nlu_args = dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))

                if self.configuration['AGENT_0']['NLU']['nlu'] == 'dummy':
                    self.nlu = DummyNLU(nlu_args)

                elif self.configuration['AGENT_0']['NLU']['nlu'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLU']['model_path']:
                        nlu_args['model_path'] = \
                            self.configuration['AGENT_0']['NLU']['model_path']
                        self.nlu = CamRestNLU(nlu_args)
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

            # NLG Settings
            if 'NLG' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLG'] and \
                    self.configuration['AGENT_0']['NLG']['nlg']:
                if self.configuration['AGENT_0']['NLG']['nlg'] == 'dummy':
                    self.nlg = DummyNLG()

                elif self.configuration['AGENT_0']['NLG']['nlg'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLG']['model_path']:
                        self.nlg = CamRestNLG({
                            'model_path':
                            self.configuration['AGENT_0']['NLG']['model_path']
                        })
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

                if self.nlg:
                    self.USE_NLG = True

            # Retrieve agent role
            if 'role' in self.configuration['AGENT_0']:
                self.agent_role = self.configuration['AGENT_0']['role']
            else:
                raise ValueError(
                    'ConversationalAgent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id))

            # A 'user' agent needs a goal generator, which in turn needs
            # both an ontology and a database.
            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator(ontology=self.ontology,
                                                        database=self.database)
                else:
                    raise ValueError(
                        'Conversational Multi Agent (user): Cannot generate '
                        'goal without ontology and database.')

        # Build the dialogue manager; keys from the AGENT_0 'DM' section
        # override the defaults assembled here.
        # NOTE(review): this runs outside `if self.configuration:`, so a
        # falsy configuration fails on the indexing below.
        dm_args = dict(
            zip([
                'settings', 'ontology', 'database', 'domain', 'agent_id',
                'agent_role'
            ], [
                self.configuration, self.ontology, self.database, self.domain,
                self.agent_id, self.agent_role
            ]))
        dm_args.update(self.configuration['AGENT_0']['DM'])
        self.dialogue_manager = DialogueManager.DialogueManager(dm_args)
class ConversationalSingleAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - Dialogue Acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd?

    - Data
    """
    def __init__(self, configuration):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        """

        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 10

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act.
        self.MAX_TURNS = 15

        self.dialogue_turn = -1
        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        self.curr_state = None
        self.prev_state = None
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            # Error checks for options the config must have
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an '
                                 'interaction mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration['AGENT_0']:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # Dialogue domain self.settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user')
                    self.user_simulator_args['us_has_initiative'] = \
                        self.USER_HAS_INITIATIVE

                if self.configuration['DIALOGUE']['domain']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if self.configuration['DIALOGUE']['ontology_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']):
                        self.ontology = Ontology.Ontology(
                            self.configuration['DIALOGUE']['ontology_path'])
                    else:
                        raise FileNotFoundError(
                            'Domain file %s not found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                if self.configuration['DIALOGUE']['db_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']):
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = DataBase.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                            else:
                                self.database = DataBase.DataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                        else:
                            # Default to SQL
                            self.database = DataBase.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path'])
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['DIALOGUE']['db_path'])

                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not found' %
                            self.configuration['DIALOGUE']['goals_path'])

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    if 'load' in \
                            self.configuration['GENERAL']['experience_logs'] \
                        and bool(
                            self.configuration['GENERAL'][
                                'experience_logs']['load']
                    ):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'Dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(self.configuration['GENERAL']
                                             ['experience_logs']['save'])

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

                elif self.configuration['GENERAL']['interaction_mode'] == \
                        'speech':
                    self.USE_SPEECH = True
                    self.asr = speech_rec.Recognizer()

            # Agent Settings

            # Usr Simulator
            # Check for specific simulator self.settings, otherwise
            # default to agenda
            if 'USER_SIMULATOR' in self.configuration['AGENT_0']:
                # Agent 0 simulator configuration
                a0_sim_config = self.configuration['AGENT_0']['USER_SIMULATOR']
                if a0_sim_config and a0_sim_config['simulator']:
                    # Default settings
                    self.user_simulator_args['ontology'] = self.ontology
                    self.user_simulator_args['database'] = self.database
                    self.user_simulator_args['um'] = self.user_model
                    self.user_simulator_args['patience'] = 5

                    if a0_sim_config['simulator'] == 'agenda':
                        if 'patience' in a0_sim_config:
                            self.user_simulator_args['patience'] = \
                                int(a0_sim_config['patience'])

                        if 'pop_distribution' in a0_sim_config:
                            if isinstance(a0_sim_config['pop_distribution'],
                                          list):
                                self.user_simulator_args['pop_distribution'] =\
                                    a0_sim_config['pop_distribution']
                            else:
                                self.user_simulator_args['pop_distribution'] =\
                                    eval(a0_sim_config['pop_distribution'])

                        if 'slot_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['slot_confuse_prob'] = \
                                float(a0_sim_config['slot_confuse_prob'])
                        if 'op_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['op_confuse_prob'] = \
                                float(a0_sim_config['op_confuse_prob'])
                        if 'value_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['value_confuse_prob'] = \
                                float(a0_sim_config['value_confuse_prob'])

                        if 'goal_slot_selection_weights' in a0_sim_config:
                            self.user_simulator_args[
                                'goal_slot_selection_weights'] = a0_sim_config[
                                    'goal_slot_selection_weights']

                        if 'nlu' in a0_sim_config:
                            self.user_simulator_args['nlu'] = \
                                a0_sim_config['nlu']

                            if self.user_simulator_args['nlu'] == 'dummy':
                                self.user_simulator_args['database'] = \
                                    self.database

                            self.USER_SIMULATOR_NLU = True

                        if 'nlg' in a0_sim_config:
                            self.user_simulator_args['nlg'] = \
                                a0_sim_config['nlg']

                            if self.user_simulator_args['nlg'] == 'CamRest':
                                if a0_sim_config:
                                    self.user_simulator_args[
                                        'nlg_model_path'] = a0_sim_config[
                                            'nlg_model_path']

                                    self.USER_SIMULATOR_NLG = True

                                else:
                                    raise ValueError(
                                        'Usr Simulator NLG: Cannot find '
                                        'model_path in the config.')

                            elif self.user_simulator_args['nlg'] == 'dummy':
                                self.USER_SIMULATOR_NLG = True

                        if 'goals_file' in a0_sim_config:
                            self.user_simulator_args['goals_file'] = \
                                a0_sim_config['goals_file']

                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']

                        self.user_simulator = AgendaBasedUS(
                            self.user_simulator_args)

                    elif a0_sim_config['simulator'] == 'dtl':
                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']
                            self.user_simulator = DTLUserSimulator(
                                self.user_simulator_args)
                        else:
                            raise ValueError(
                                'Error! Cannot start DAct-to-Language '
                                'simulator without a policy file!')

                else:
                    # Fallback to agenda based simulator with default settings
                    self.user_simulator = AgendaBasedUS(
                        self.user_simulator_args)

            # NLU Settings
            if 'NLU' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLU'] and \
                    self.configuration['AGENT_0']['NLU']['nlu']:
                nlu_args = dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))

                if self.configuration['AGENT_0']['NLU']['nlu'] == 'dummy':
                    self.nlu = DummyNLU(nlu_args)

                elif self.configuration['AGENT_0']['NLU']['nlu'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLU']['model_path']:
                        nlu_args['model_path'] = \
                            self.configuration['AGENT_0']['NLU']['model_path']
                        self.nlu = CamRestNLU(nlu_args)
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

            # NLG Settings
            if 'NLG' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLG'] and \
                    self.configuration['AGENT_0']['NLG']['nlg']:
                if self.configuration['AGENT_0']['NLG']['nlg'] == 'dummy':
                    self.nlg = DummyNLG()

                elif self.configuration['AGENT_0']['NLG']['nlg'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLG']['model_path']:
                        self.nlg = CamRestNLG({
                            'model_path':
                            self.configuration['AGENT_0']['NLG']['model_path']
                        })
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

                if self.nlg:
                    self.USE_NLG = True

            # Retrieve agent role
            if 'role' in self.configuration['AGENT_0']:
                self.agent_role = self.configuration['AGENT_0']['role']
            else:
                raise ValueError(
                    'ConversationalAgent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id))

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator(ontology=self.ontology,
                                                        database=self.database)
                else:
                    raise ValueError(
                        'Conversational Multi Agent (user): Cannot generate '
                        'goal without ontology and database.')

        dm_args = dict(
            zip([
                'settings', 'ontology', 'database', 'domain', 'agent_id',
                'agent_role'
            ], [
                self.configuration, self.ontology, self.database, self.domain,
                self.agent_id, self.agent_role
            ]))
        dm_args.update(self.configuration['AGENT_0']['DM'])
        self.dialogue_manager = DialogueManager.DialogueManager(dm_args)

    def __del__(self):
        """
        Do some house-keeping and save the models.

        Saves the experience log (when ``SAVE_LOG`` is set) and the dialogue
        manager, then clears the per-turn bookkeeping attributes. Attributes
        are looked up defensively because Python also calls ``__del__`` on
        instances whose ``__init__`` raised before setting them.

        :return: nothing
        """

        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', True):
            recorder.save()

        dialogue_manager = getattr(self, 'dialogue_manager', None)
        if dialogue_manager:
            dialogue_manager.save()

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Initializes the conversational agent based on settings in the
        configuration file.

        Resets the dialogue statistics and initializes the NLU (if any), the
        dialogue manager and the NLG (if any), then clears the per-turn
        bookkeeping. An agent playing the user role always hands its goal to
        the dialogue manager, generating one first if none has been set.

        :return: Nothing
        """

        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0

        if self.nlu:
            self.nlu.initialize({})

        if self.agent_role == 'user':
            # Keep a goal that was set beforehand and only generate one when
            # missing; either way the dialogue manager must receive it.
            if not self.agent_goal:
                self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.initialize({'goal': self.agent_goal})

        else:
            self.dialogue_manager.initialize({})

        if self.nlg:
            self.nlg.initialize({})

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def start_dialogue(self, args=None):
        """
        Perform initial dialogue turn.

        Resets the turn counter and re-initializes the user simulator (when
        one is used). Restarts the dialogue manager, with a newly generated
        goal if this agent plays the user role.

        If the user does not have the initiative, this agent opens with a
        ``welcomemsg`` act. That act is rendered (NLG, optionally speech),
        passed to the user simulator, scored and recorded. Per-turn
        bookkeeping is then reset and the first regular turn is run via
        ``continue_dialogue``.

        :param args: optional args (not used)
        :return: nothing
        """
        import shlex

        self.dialogue_turn = 0
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            self.user_simulator.initialize(self.user_simulator_args)

            print('DEBUG > Usr goal:')
            print(self.user_simulator.goal)

        if self.agent_role == 'user':
            self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.restart({'goal': self.agent_goal})

        else:
            self.dialogue_manager.restart({})

        if not self.USER_HAS_INITIATIVE:
            # The system opens the dialogue with a welcome message.
            sys_response = [DialogueAct('welcomemsg', [])]

            if self.USE_NLG:
                sys_utterance = self.nlg.generate_output(
                    {'dacts': sys_response})
                print('SYSTEM > %s ' % sys_utterance)

                if self.USE_SPEECH:
                    try:
                        tts = gTTS(sys_utterance)
                        tts.save('sys_output.mp3')
                        os.system('afplay sys_output.mp3')

                    except Exception as e:
                        print('WARNING: gTTS encountered an error: {0}. '
                              'Falling back to Sys TTS.'.format(e))
                        # Quote the utterance so apostrophes (e.g. "I'm") or
                        # shell metacharacters in it cannot break the command.
                        os.system('say ' + shlex.quote(sys_utterance))
            else:
                print('SYSTEM > %s ' %
                      '; '.join([str(sr) for sr in sys_response]))

            if self.USE_USR_SIMULATOR:
                usim_input = sys_response

                if self.USER_SIMULATOR_NLU and self.USE_NLG:
                    usim_input = self.user_simulator.nlu.process_input(
                        sys_utterance)

                self.user_simulator.receive_input(usim_input)
                rew, success, task_success = self.reward_func.calculate(
                    self.dialogue_manager.get_state(), sys_response,
                    self.user_simulator.goal)
            else:
                rew, success, task_success = 0, None, None

            # task_success is passed by keyword, matching end_dialogue;
            # passing it positionally would bind it to whatever record()'s
            # sixth parameter happens to be.
            self.recorder.record(deepcopy(self.dialogue_manager.get_state()),
                                 self.dialogue_manager.get_state(),
                                 sys_response,
                                 rew,
                                 success,
                                 task_success=task_success,
                                 output_utterance=sys_utterance)

            self.dialogue_turn += 1

        self.prev_state = None

        # Re-initialize these for good measure
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.continue_dialogue()

    def continue_dialogue(self):
        """
        Perform next dialogue turn.

        Obtains the user's input from the user simulator, the microphone or
        the keyboard, and converts it to dialogue acts via NLU where needed.
        Feeds it to the dialogue manager and produces this agent's response,
        forcing a ``bye`` act once ``MAX_TURNS`` is reached.

        The response is rendered via NLG / speech when enabled. In
        simulation it is also sent to the user simulator and scored by the
        reward function. The previous turn is then recorded and the current
        one is stored for the next call.

        :return: nothing
        :raises EnvironmentError: if a real user interacts via text/speech
            but no NLU is defined
        """
        import shlex

        usr_utterance = ''
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            usr_input = self.user_simulator.respond()

            # TODO: THIS FIRST IF WILL BE HANDLED BY ConversationalAgentGeneric
            #  -- SHOULD NOT LIVE HERE
            if isinstance(self.user_simulator, DTLUserSimulator):
                print('USER (NLG) > %s \n' % usr_input)
                usr_input = self.nlu.process_input(
                    usr_input, self.dialogue_manager.get_state())

            elif self.USER_SIMULATOR_NLG:
                print('USER > %s \n' % usr_input)

                # Without an NLU it will just print the user's NLG but use
                # the simulator's output to proceed.
                if self.nlu:
                    usr_input = self.nlu.process_input(usr_input)

            else:
                print('USER (DACT) > %s \n' % usr_input[0])

        else:
            if self.USE_SPEECH:
                # Listen for input from the microphone
                with speech_rec.Microphone() as source:
                    print('(listening...)')
                    audio = self.asr.listen(source, phrase_time_limit=3)

                try:
                    # This uses the default key
                    usr_utterance = self.asr.recognize_google(audio)
                    print("Google ASR: " + usr_utterance)

                except speech_rec.UnknownValueError:
                    print("Google ASR did not understand you")

                except speech_rec.RequestError as e:
                    print("Google ASR request error: {0}".format(e))

            else:
                usr_utterance = input('USER > ')

            # Process the user's utterance
            if self.nlu:
                usr_input = self.nlu.process_input(
                    usr_utterance, self.dialogue_manager.get_state())
            else:
                raise EnvironmentError(
                    'ConversationalAgent: No NLU defined for '
                    'text-based interaction!')

        self.dialogue_manager.receive_input(usr_input)

        # Snapshot the state the dialogue manager will use to make its
        # decision; it becomes prev_state (for the recorder) at the end of
        # this turn.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        if self.dialogue_turn < self.MAX_TURNS:
            sys_response = self.dialogue_manager.generate_output()

        else:
            # Force dialogue stop
            sys_response = [DialogueAct('bye', [])]

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output({'dacts': sys_response})
            print('SYSTEM > %s ' % sys_utterance)

            if self.USE_SPEECH:
                try:
                    tts = gTTS(text=sys_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')

                # Catch Exception rather than everything so Ctrl+C and
                # SystemExit still propagate.
                except Exception as e:
                    print('WARNING: gTTS encountered an error: {0}. '
                          'Falling back to Sys TTS.'.format(e))
                    # Quote the utterance so apostrophes or shell
                    # metacharacters in it cannot break the command.
                    os.system('say ' + shlex.quote(sys_utterance))
        else:
            print('SYSTEM > %s ' % '; '.join([str(sr) for sr in sys_response]))

        if self.USE_USR_SIMULATOR:
            usim_input = sys_response

            if self.USER_SIMULATOR_NLU and self.USE_NLG:
                usim_input = \
                    self.user_simulator.nlu.process_input(sys_utterance)

                print('USER NLU '
                      '> %s ' % '; '.join([str(ui) for ui in usim_input]))

            self.user_simulator.receive_input(usim_input)
            rew, success, task_success = \
                self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
        else:
            rew, success, task_success = 0, None, None

        # Record the previous turn (there is none on the first call). The
        # task success is forwarded like end_dialogue does for the last turn.
        if self.prev_state:
            self.recorder.record(self.prev_state,
                                 self.curr_state,
                                 self.prev_action,
                                 self.prev_reward,
                                 self.prev_success,
                                 input_utterance=usr_utterance,
                                 output_utterance=sys_utterance,
                                 task_success=self.prev_task_success)

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_action = deepcopy(sys_response)
        self.prev_usr_utterance = deepcopy(usr_utterance)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

    def end_dialogue(self):
        """
        Perform final dialogue turn. Train and save models if applicable.

        Records the final state. When the dialogue manager is training, it
        trains the NLU, dialogue manager and NLG on randomly sampled
        minibatches every ``train_interval`` episodes, once at least
        ``minibatch_length`` dialogues are logged. It then advances the
        episode counter, saves the dialogue manager every 10000 episodes and
        updates the reward / success statistics.

        :return: nothing
        """

        # Record final state
        self.recorder.record(self.curr_state,
                             self.curr_state,
                             self.prev_action,
                             self.prev_reward,
                             self.prev_success,
                             input_utterance=self.prev_usr_utterance,
                             output_utterance=self.prev_sys_utterance,
                             task_success=self.prev_task_success)

        if self.dialogue_manager.is_training():
            if self.dialogue_episode % self.train_interval == 0 and \
                    len(self.recorder.dialogues) >= self.minibatch_length:
                for epoch in range(self.train_epochs):
                    # Report epochs 1-based so the last one reads "N of N".
                    print('Training epoch {0} of {1}'.format(
                        epoch + 1, self.train_epochs))

                    # Sample minibatch
                    minibatch = random.sample(self.recorder.dialogues,
                                              self.minibatch_length)

                    if self.nlu:
                        self.nlu.train(minibatch)

                    self.dialogue_manager.train(minibatch)

                    if self.nlg:
                        self.nlg.train(minibatch)

        self.dialogue_episode += 1

        # Episode statistics are read from the last recorded turn of the
        # latest dialogue.
        last_turn = self.recorder.dialogues[-1][-1]

        cumulative_reward = last_turn['cumulative_reward']
        self.cumulative_rewards += cumulative_reward
        print('CUMULATIVE REWARD: {0}'.format(cumulative_reward))

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % 10000 == 0:
            self.dialogue_manager.save()

        # Count successful dialogues
        if last_turn['success']:
            print('SUCCESS (Subjective)!')
            self.num_successful_dialogues += int(last_turn['success'])

        else:
            print('FAILURE (Subjective).')

        if last_turn['task_success']:
            self.num_task_success += int(last_turn['task_success'])

        print('OBJECTIVE TASK SUCCESS: {0}'.format(last_turn['task_success']))

    def terminated(self):
        """
        Report whether the dialogue has reached a terminal state.

        Delegates the decision to the dialogue manager.

        :return: True or False
        """
        manager = self.dialogue_manager
        return manager.at_terminal_state()
Ejemplo n.º 5
0
    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        Sets default statistics and settings. It then reads the GENERAL
        section (interaction mode, experience logs, global arguments) and
        the AGENT_<agent_id> section of the configuration.

        Each MODULE_<m> listed there is instantiated via ``load_module``. A
        module entry with a ``parallel_modules`` count is loaded as a list
        of its PARALLEL_MODULE_<pm> entries. Every module receives a deep
        copy of the global arguments (which include the whole configuration
        under ``settings``), updated with its own ``arguments``.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        :raises AttributeError: if no configuration is provided
        :raises ValueError: if a required section, package or class entry is
            missing from the configuration
        :raises FileNotFoundError: if experience logs must be loaded but the
            log file does not exist
        """
        self.agent_id = agent_id

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 250
        self.TRAIN_INTERVAL = 50
        self.train_epochs = 10

        self.configuration = configuration

        self.recorder = DialogueEpisodeRecorder()

        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000
        self.MAX_TURNS = 15
        self.INTERACTION_MODE = 'simulation'

        self.reward_func = SlotFillingGoalAdvancementReward()

        self.ConversationalModules = []
        self.prev_m_out = ConversationalFrame({})

        self.goal_generator = None
        self.agent_goal = None

        if not self.configuration:
            raise AttributeError('ConversationalGenericAgent: '
                                 'No settings (config) provided!')

        agent_section = 'AGENT_' + str(agent_id)
        if 'GENERAL' not in self.configuration:
            raise ValueError('No GENERAL section in config!')
        if agent_section not in self.configuration:
            raise ValueError(f'NO AGENT_{agent_id} section in config!')

        general_config = self.configuration['GENERAL']
        agent_config = self.configuration[agent_section]

        if 'interaction_mode' in general_config:
            self.INTERACTION_MODE = general_config['interaction_mode']

        if 'experience_logs' in general_config:
            logs_config = general_config['experience_logs']

            dialogues_path = None
            if 'path' in logs_config:
                dialogues_path = logs_config['path']

            if 'load' in logs_config and bool(logs_config['load']):
                if dialogues_path and os.path.isfile(dialogues_path):
                    self.recorder.load(dialogues_path)
                else:
                    raise FileNotFoundError(
                        'Dialogue Log file %s not found (did you '
                        'provide one?)' % dialogues_path)

            if 'save' in logs_config:
                self.recorder.set_path(dialogues_path)
                self.SAVE_LOG = bool(logs_config['save'])

        self.NModules = 0
        if 'modules' in agent_config:
            self.NModules = int(agent_config['modules'])

        # Note: Since we pass settings as a default argument, any
        #       module can access the global args. However, we
        #       add it here too for ease of use.
        self.global_arguments = {'settings': self.configuration}
        if 'global_arguments' in general_config:
            self.global_arguments.update(general_config['global_arguments'])

        # Load the modules
        for m in range(self.NModules):
            module_section = 'MODULE_' + str(m)
            if module_section not in agent_config:
                raise ValueError(f'No MODULE_{m} section in config!')
            module_config = agent_config[module_section]

            if 'parallel_modules' in module_config:
                parallel_modules = []

                # Cast like 'modules' above so a numeric string also works.
                for pm in range(int(module_config['parallel_modules'])):
                    pm_section = 'PARALLEL_MODULE_' + str(pm)
                    if pm_section not in module_config:
                        raise ValueError(
                            f'No PARALLEL_MODULE_{pm} section for module '
                            f'{m} in config!')
                    pm_config = module_config[pm_section]

                    if 'package' not in pm_config:
                        raise ValueError(
                            f'No package provided for parallel module '
                            f'{pm} of module {m}!')
                    if 'class' not in pm_config:
                        raise ValueError(
                            f'No class provided for parallel module '
                            f'{pm} of module {m}!')

                    # Start from the global arguments (configuration
                    # included) so module-specific arguments override them.
                    args = deepcopy(self.global_arguments)
                    if 'arguments' in pm_config:
                        args.update(pm_config['arguments'])

                    parallel_modules.append(
                        self.load_module(pm_config['package'],
                                         pm_config['class'], args))

                self.ConversationalModules.append(parallel_modules)

            else:
                if 'package' not in module_config:
                    raise ValueError(f'No package provided for module {m}!')
                if 'class' not in module_config:
                    raise ValueError(f'No class provided for module {m}!')

                # Start from the global arguments (configuration included)
                # so module-specific arguments override them.
                args = deepcopy(self.global_arguments)
                if 'arguments' in module_config:
                    args.update(module_config['arguments'])

                self.ConversationalModules.append(
                    self.load_module(module_config['package'],
                                     module_config['class'], args))
Ejemplo n.º 6
0
class ConversationalGenericAgent(ConversationalAgent):
    """
    The ConversationalGenericAgent receives a list of modules in
    its configuration file, that are chained together serially -
    i.e. the input to the agent is passed to the first module,
    the first module's output is passed as input to the second
    module and so on. Modules are wrapped using ConversationalModules.
    The input and output passed between modules is wrapped into
    ConversationalFrames.
    """
    def __init__(self, configuration, agent_id):
        """
        Initialize the internal structures of this agent.

        Sets default statistics and settings. It then reads the GENERAL
        section (interaction mode, experience logs, global arguments) and
        the AGENT_<agent_id> section of the configuration.

        Each MODULE_<m> listed there is instantiated via ``load_module``. A
        module entry with a ``parallel_modules`` count is loaded as a list
        of its PARALLEL_MODULE_<pm> entries. Every module receives a deep
        copy of the global arguments (which include the whole configuration
        under ``settings``), updated with its own ``arguments``.

        :param configuration: a dictionary representing the configuration file
        :param agent_id: an integer, this agent's id
        :raises AttributeError: if no configuration is provided
        :raises ValueError: if a required section, package or class entry is
            missing from the configuration
        :raises FileNotFoundError: if experience logs must be loaded but the
            log file does not exist
        """
        self.agent_id = agent_id

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 250
        self.TRAIN_INTERVAL = 50
        self.train_epochs = 10

        self.configuration = configuration

        self.recorder = DialogueEpisodeRecorder()

        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000
        self.MAX_TURNS = 15
        self.INTERACTION_MODE = 'simulation'

        self.reward_func = SlotFillingGoalAdvancementReward()

        self.ConversationalModules = []
        self.prev_m_out = ConversationalFrame({})

        self.goal_generator = None
        self.agent_goal = None

        if not self.configuration:
            raise AttributeError('ConversationalGenericAgent: '
                                 'No settings (config) provided!')

        agent_section = 'AGENT_' + str(agent_id)
        if 'GENERAL' not in self.configuration:
            raise ValueError('No GENERAL section in config!')
        if agent_section not in self.configuration:
            raise ValueError(f'NO AGENT_{agent_id} section in config!')

        general_config = self.configuration['GENERAL']
        agent_config = self.configuration[agent_section]

        if 'interaction_mode' in general_config:
            self.INTERACTION_MODE = general_config['interaction_mode']

        if 'experience_logs' in general_config:
            logs_config = general_config['experience_logs']

            dialogues_path = None
            if 'path' in logs_config:
                dialogues_path = logs_config['path']

            if 'load' in logs_config and bool(logs_config['load']):
                if dialogues_path and os.path.isfile(dialogues_path):
                    self.recorder.load(dialogues_path)
                else:
                    raise FileNotFoundError(
                        'Dialogue Log file %s not found (did you '
                        'provide one?)' % dialogues_path)

            if 'save' in logs_config:
                self.recorder.set_path(dialogues_path)
                self.SAVE_LOG = bool(logs_config['save'])

        self.NModules = 0
        if 'modules' in agent_config:
            self.NModules = int(agent_config['modules'])

        # Note: Since we pass settings as a default argument, any
        #       module can access the global args. However, we
        #       add it here too for ease of use.
        self.global_arguments = {'settings': self.configuration}
        if 'global_arguments' in general_config:
            self.global_arguments.update(general_config['global_arguments'])

        # Load the modules
        for m in range(self.NModules):
            module_section = 'MODULE_' + str(m)
            if module_section not in agent_config:
                raise ValueError(f'No MODULE_{m} section in config!')
            module_config = agent_config[module_section]

            if 'parallel_modules' in module_config:
                parallel_modules = []

                # Cast like 'modules' above so a numeric string also works.
                for pm in range(int(module_config['parallel_modules'])):
                    pm_section = 'PARALLEL_MODULE_' + str(pm)
                    if pm_section not in module_config:
                        raise ValueError(
                            f'No PARALLEL_MODULE_{pm} section for module '
                            f'{m} in config!')
                    pm_config = module_config[pm_section]

                    if 'package' not in pm_config:
                        raise ValueError(
                            f'No package provided for parallel module '
                            f'{pm} of module {m}!')
                    if 'class' not in pm_config:
                        raise ValueError(
                            f'No class provided for parallel module '
                            f'{pm} of module {m}!')

                    # Start from the global arguments (configuration
                    # included) so module-specific arguments override them.
                    args = deepcopy(self.global_arguments)
                    if 'arguments' in pm_config:
                        args.update(pm_config['arguments'])

                    parallel_modules.append(
                        self.load_module(pm_config['package'],
                                         pm_config['class'], args))

                self.ConversationalModules.append(parallel_modules)

            else:
                if 'package' not in module_config:
                    raise ValueError(f'No package provided for module {m}!')
                if 'class' not in module_config:
                    raise ValueError(f'No class provided for module {m}!')

                # Start from the global arguments (configuration included)
                # so module-specific arguments override them.
                args = deepcopy(self.global_arguments)
                if 'arguments' in module_config:
                    args.update(module_config['arguments'])

                self.ConversationalModules.append(
                    self.load_module(module_config['package'],
                                     module_config['class'], args))

        # TODO: Parse config modules I/O and raise error if
        #       any inconsistencies found

    def __del__(self):
        """
        Do some house-keeping when the agent is garbage-collected: save the
        dialogue log and every module's model.

        The recorder is saved only if it exists (is truthy) and SAVE_LOG is
        set. Every entry of ConversationalModules is saved; entries that are
        lists (groups of parallel sub-modules) have each sub-module saved.

        Attributes are read with getattr() defaults because __del__ also runs
        on partially constructed instances (e.g. when construction raises
        before these attributes are assigned); plain attribute access would
        then raise an AttributeError that Python reports as ignored.

        :return: nothing
        """

        recorder = getattr(self, 'recorder', None)
        if recorder and getattr(self, 'SAVE_LOG', False):
            recorder.save()

        for m in getattr(self, 'ConversationalModules', ()):
            # A list entry is a group of parallel sub-modules.
            sub_modules = m if isinstance(m, list) else (m,)
            for sm in sub_modules:
                sm.save()

    # Dynamically load classes
    @staticmethod
    def load_module(package_path, class_name, args):
        """
        Dynamically load the specified class.

        :param package_path: Path to the package to load
        :param class_name: Name of the class within the package
        :param args: arguments to pass when creating the object
        :return: the instantiated class object
        """
        module = __import__(package_path, fromlist=[class_name])
        klass = getattr(module, class_name)
        return klass(args)

    def initialize(self):
        """
        Initialize the agent: reset its statistics and initialize its modules.

        The episode and turn counters and the dialogue statistics are set to
        0 and the agent's goal is cleared. Then initialize({}) is called on
        every conversational module in order; a list entry (a group of
        parallel sub-modules) has each of its sub-modules initialized.

        :return: Nothing
        """

        self.dialogue_episode = self.dialogue_turn = 0
        self.num_successful_dialogues = self.num_task_success = 0
        self.cumulative_rewards = 0
        self.agent_goal = None

        for entry in self.ConversationalModules:
            group = entry if isinstance(entry, list) else (entry,)
            for module in group:
                module.initialize({})

    def start_dialogue(self, args=None):
        """
        Reset per-dialogue state at the beginning of a dialogue and run the
        first turn.

        Only per-dialogue state is reset: the turn counter, the agent's goal,
        and every conversational module (initialize({}) is called on each,
        including parallel sub-modules). Lifetime statistics
        (dialogue_episode, num_successful_dialogues, num_task_success,
        cumulative_rewards) are deliberately NOT reset here. end_dialogue()
        uses dialogue_episode to schedule training and saving. Calling
        initialize() at every dialogue set it back to 0, which made those
        checks fire after every episode and kept the statistics from
        accumulating.

        :param args: unused
        :return: a tuple (content of the first turn's output, '', this
                 agent's goal)
        """

        self.dialogue_turn = 0
        # NOTE(review): as before (via initialize()), this discards any goal
        # set through set_goal() before the dialogue started -- confirm.
        self.agent_goal = None

        for m in self.ConversationalModules:
            # A list entry is a group of parallel sub-modules.
            sub_modules = m if isinstance(m, list) else (m,)
            for sm in sub_modules:
                sm.initialize({})

        # TODO: Get initial trigger from config
        self.prev_m_out = ConversationalFrame({'utterance': 'Hello'})
        self.continue_dialogue()

        return self.prev_m_out.content, '', self.agent_goal

    def continue_dialogue(self, args=None):
        """
        Perform one dialogue turn.

        In 'text' interaction mode the user's input is read from stdin and
        becomes the pipeline input; otherwise the previous output
        (self.prev_m_out) is used. The input is passed through
        self.ConversationalModules in order, each module's output becoming
        the next module's input.

        A list entry is a group of parallel sub-modules. Every sub-module
        receives the same deep-copied input, and their outputs are merged
        into one frame whose content maps 'sm<i>' to the content of the i-th
        sub-module's output.

        :param args: input to this agent (currently unused by this method)
        :return: a tuple (content of the final module output, '', this
                 agent's goal)
        """

        if self.INTERACTION_MODE == 'text':
            # input() returns a plain str. Wrap it so the pipeline below
            # always works on a ConversationalFrame: the parallel branch, the
            # debug print and the return value all access .content, which a
            # str does not have.
            self.prev_m_out = ConversationalFrame(input('USER > '))

        for m in self.ConversationalModules:
            # If executing parallel sub-modules
            if isinstance(m, list):
                prev_m_out = deepcopy(self.prev_m_out)
                self.prev_m_out.content = {}

                for idx, sm in enumerate(m):
                    # WARNING! Module compatibility cannot be guaranteed here!
                    sm.generic_receive_input(prev_m_out)
                    sm_out = sm.generic_generate_output(prev_m_out)

                    if not isinstance(sm_out, ConversationalFrame):
                        sm_out = ConversationalFrame(sm_out)

                    self.prev_m_out.content['sm' + str(idx)] = sm_out.content

            else:
                # WARNING! Module compatibility cannot be guaranteed here!
                m.generic_receive_input(self.prev_m_out)
                self.prev_m_out = m.generic_generate_output(self.prev_m_out)

                # Make sure prev_m_out is a Conversational Frame
                if not isinstance(self.prev_m_out, ConversationalFrame):
                    self.prev_m_out = ConversationalFrame(self.prev_m_out)

            # DEBUG:
            if isinstance(self.prev_m_out.content, str):
                print('DEBUG> ' + str(self.prev_m_out.content))

        self.dialogue_turn += 1

        return self.prev_m_out.content, '', self.agent_goal

    def end_dialogue(self):
        """
        Finish the current dialogue episode.

        When the episode number is a multiple of TRAIN_INTERVAL, every module
        (parallel sub-modules included) is trained on the recorder's
        dialogues. When it is a multiple of SAVE_INTERVAL, every module is
        saved. Finally the episode counter is advanced by one.

        :return: nothing
        """

        def every_module():
            # Flatten parallel groups (list entries) into their sub-modules,
            # preserving order.
            for entry in self.ConversationalModules:
                if isinstance(entry, list):
                    yield from entry
                else:
                    yield entry

        if self.dialogue_episode % self.TRAIN_INTERVAL == 0:
            for module in every_module():
                module.train(self.recorder.dialogues)

        if self.dialogue_episode % self.SAVE_INTERVAL == 0:
            for module in every_module():
                module.save()

        self.dialogue_episode += 1

    def terminated(self):
        """
        Report whether this agent has reached a terminal state.

        The last conversational module is asked first; if its
        at_terminal_state() result is truthy, that result is returned.
        Otherwise the dialogue is terminal once dialogue_turn exceeds
        MAX_TURNS.

        :return: True or False (or the truthy value reported by the module)
        """

        # TODO: Set at config which module controls the state
        module_state = self.ConversationalModules[-1].at_terminal_state()
        if module_state:
            return module_state

        return self.dialogue_turn > self.MAX_TURNS

    def set_goal(self, goal):
        """
        Store (or replace) the goal this agent pursues.

        :param goal: a Goal; stored as-is in self.agent_goal
        :return: nothing
        """

        self.agent_goal = goal

    def get_goal(self):
        """
        Return the goal currently stored for this agent.

        :return: a Goal (whatever was last stored in self.agent_goal)
        """

        current_goal = self.agent_goal
        return current_goal
    args = {'path': data_path, 'ontology': ontology, 'database': database}

    parser.initialize(**args)

    print('Parsing {0}'.format(args['path']))

    parser.parse_data()

    print('Data parsing complete.')

    # Save data

    parser.save('Logs')

    # Load data
    recorder_sys = DialogueEpisodeRecorder(path='Logs/DSTC2_system')
    recorder_usr = DialogueEpisodeRecorder(path='Logs/DSTC2_user')

    # Train Supervised Models using the recorded data
    system_policy_supervised = SupervisedPolicy(ontology,
                                                database,
                                                agent_role='system',
                                                agent_id=0,
                                                domain='CamRest')

    user_policy_supervised = SupervisedPolicy(ontology,
                                              database,
                                              agent_role='user',
                                              agent_id=1,
                                              domain='CamRest')