class ConversationalSingleAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - Dialogue acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd (potentially)

    - Parsers
    """

    def __init__(self, configuration):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        """

        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        # Default meta-parameter values
        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 3

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True
        self.SAVE_INTERVAL = 10000

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 15

        self.dialogue_turn = -1
        self.ontology = None
        self.database = None
        self.domain = None
        self.global_args = {}
        self.dialogue_manager = None
        self.user_model = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            # Error checks for options the config must have
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an '
                                 'interaction mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration['AGENT_0']:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # Dialogue domain settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user'
                    )
                    self.user_simulator_args['us_has_initiative'] = \
                        self.USER_HAS_INITIATIVE

                if 'domain' in self.configuration['DIALOGUE']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if 'ontology_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']
                    ):
                        self.ontology = ontology.Ontology(
                            self.configuration['DIALOGUE']['ontology_path']
                        )
                    else:
                        raise FileNotFoundError(
                            'Ontology file %s not found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                # Alternatively, look at global_arguments for ontology path
                elif 'global_arguments' in self.configuration['GENERAL'] \
                        and 'ontology' in \
                        self.configuration['GENERAL']['global_arguments']:
                    if os.path.isfile(
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology']
                    ):
                        self.ontology = ontology.Ontology(
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology']
                        )
                    else:
                        raise FileNotFoundError(
                            'Ontology file %s not found' %
                            self.configuration['GENERAL'][
                                'global_arguments']['ontology'])

                if 'db_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']
                    ):
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = database.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path']
                                )
                            else:
                                self.database = database.DataBase(
                                    self.configuration['DIALOGUE']['db_path']
                                )
                        else:
                            # Default to SQL
                            self.database = database.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path']
                            )
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['DIALOGUE']['db_path']
                        )

                # Alternatively, look at global arguments for db path
                elif 'global_arguments' in self.configuration['GENERAL'] \
                        and 'database' in \
                        self.configuration['GENERAL']['global_arguments']:
                    if os.path.isfile(
                            self.configuration['GENERAL'][
                                'global_arguments']['database']
                    ):
                        self.database = database.DataBase(
                            self.configuration['GENERAL'][
                                'global_arguments']['database']
                        )
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['GENERAL'][
                                'global_arguments']['database'])

                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']
                    ):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not found' %
                            self.configuration['DIALOGUE']['goals_path']
                        )

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'global_arguments' in self.configuration['GENERAL']:
                    self.global_args = \
                        self.configuration['GENERAL']['global_arguments']
                    
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    if 'load' in \
                            self.configuration['GENERAL']['experience_logs'] \
                            and bool(self.configuration['GENERAL'][
                                         'experience_logs']['load']):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(
                            self.configuration['GENERAL'][
                                'experience_logs']['save']
                        )

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

                elif self.configuration['GENERAL']['interaction_mode'] == \
                        'speech':
                    self.USE_SPEECH = True
                    self.asr = speech_rec.Recognizer()

            # Agent Settings

            # Retrieve agent role
            if 'role' in self.configuration['AGENT_0']:
                self.agent_role = self.configuration['AGENT_0']['role']
            else:
                raise ValueError(
                    'agent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id)
                )

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator({
                        'ontology': self.ontology,
                        'database': self.database
                    })
                else:
                    raise ValueError(
                        'Conversational Single Agent (user): Cannot generate '
                        'goal without ontology and database.'
                    )
                
            # Retrieve agent parameters
            if 'max_turns' in self.configuration['AGENT_0']:
                self.MAX_TURNS = self.configuration['AGENT_0']['max_turns']
                
            if 'train_interval' in self.configuration['AGENT_0']:
                self.train_interval = \
                    self.configuration['AGENT_0']['train_interval']
                
            if 'train_minibatch' in self.configuration['AGENT_0']:
                self.minibatch_length = \
                    self.configuration['AGENT_0']['train_minibatch']
                
            if 'train_epochs' in self.configuration['AGENT_0']:
                self.train_epochs = \
                    self.configuration['AGENT_0']['train_epochs']

            if 'save_interval' in self.configuration['AGENT_0']:
                self.SAVE_INTERVAL = \
                    self.configuration['AGENT_0']['save_interval']
            
            # User Simulator
            # Check for specific simulator settings, otherwise
            # default to the agenda-based simulator
            if 'USER_SIMULATOR' in self.configuration['AGENT_0']:
                # Agent 0 simulator configuration
                if 'package' in \
                    self.configuration['AGENT_0']['USER_SIMULATOR'] and \
                        'class' in \
                        self.configuration['AGENT_0']['USER_SIMULATOR']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['USER_SIMULATOR']:
                        self.user_simulator_args = \
                            self.configuration[
                                'AGENT_0']['USER_SIMULATOR']['arguments']

                    self.user_simulator_args.update(self.global_args)
                    
                    self.user_simulator = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['USER_SIMULATOR'][
                                'package'],
                            self.configuration['AGENT_0']['USER_SIMULATOR'][
                                'class'],
                            self.user_simulator_args
                        )

                    if hasattr(self.user_simulator, 'nlu'):
                        self.USER_SIMULATOR_NLU = self.user_simulator.nlu
                    if hasattr(self.user_simulator, 'nlg'):
                        self.USER_SIMULATOR_NLG = self.user_simulator.nlg
                else:
                    # Fallback to agenda based simulator with default settings
                    self.user_simulator = AgendaBasedUS(
                        self.user_simulator_args
                    )

            # NLU Settings
            if 'NLU' in self.configuration['AGENT_0']:
                nlu_args = {}
                if 'package' in self.configuration['AGENT_0']['NLU'] and \
                        'class' in self.configuration['AGENT_0']['NLU']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['NLU']:
                        nlu_args = \
                            self.configuration['AGENT_0']['NLU']['arguments']

                    nlu_args.update(self.global_args)

                    self.nlu = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['NLU'][
                                'package'],
                            self.configuration['AGENT_0']['NLU'][
                                'class'],
                            nlu_args
                        )
                    
            # DM Settings
            if 'DM' in self.configuration['AGENT_0']:
                dm_args = dict(
                    zip(
                        ['settings', 'ontology', 'database', 'domain',
                         'agent_id',
                         'agent_role'],
                        [self.configuration,
                         self.ontology,
                         self.database,
                         self.domain,
                         self.agent_id,
                         self.agent_role
                         ]
                    )
                )

                if 'package' in self.configuration['AGENT_0']['DM'] and \
                        'class' in self.configuration['AGENT_0']['DM']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['DM']:
                        dm_args.update(
                            self.configuration['AGENT_0']['DM']['arguments']
                        )

                    dm_args.update(self.global_args)

                    self.dialogue_manager = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['DM'][
                                'package'],
                            self.configuration['AGENT_0']['DM'][
                                'class'],
                            dm_args
                        )

            # NLG Settings
            if 'NLG' in self.configuration['AGENT_0']:
                nlg_args = {}
                if 'package' in self.configuration['AGENT_0']['NLG'] and \
                        'class' in self.configuration['AGENT_0']['NLG']:
                    if 'arguments' in \
                            self.configuration['AGENT_0']['NLG']:
                        nlg_args = \
                            self.configuration['AGENT_0']['NLG']['arguments']

                    nlg_args.update(self.global_args)

                    self.nlg = \
                        ConversationalGenericAgent.load_module(
                            self.configuration['AGENT_0']['NLG'][
                                'package'],
                            self.configuration['AGENT_0']['NLG'][
                                'class'],
                            nlg_args
                        )

                if self.nlg:
                    self.USE_NLG = True

        # True if at least one module is training
        self.IS_TRAINING = bool(
            (self.nlu and self.nlu.training) or
            (self.dialogue_manager and self.dialogue_manager.training) or
            (self.nlg and self.nlg.training)
        )

    def __del__(self):
        """
        Do some house-keeping and save the models.

        :return: nothing
        """

        if self.recorder and self.SAVE_LOG:
            self.recorder.save()

        if self.nlu:
            self.nlu.save()

        if self.dialogue_manager:
            self.dialogue_manager.save()

        if self.nlg:
            self.nlg.save()

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Initializes the conversational agent based on settings in the
        configuration file.

        :return: Nothing
        """

        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0

        if self.nlu:
            self.nlu.initialize({})

        self.dialogue_manager.initialize({})

        if self.nlg:
            self.nlg.initialize({})

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def start_dialogue(self, args=None):
        """
        Perform initial dialogue turn.

        :param args: optional args
        :return: nothing
        """

        self.dialogue_turn = 0
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            self.user_simulator.initialize(self.user_simulator_args)

            print('DEBUG > usr goal:')
            print(self.user_simulator.goal)

        self.dialogue_manager.restart({})

        if not self.USER_HAS_INITIATIVE:
            # sys_response = self.dialogue_manager.respond()
            sys_response = [DialogueAct('welcomemsg', [])]

            if self.USE_NLG:
                sys_utterance = self.nlg.generate_output(
                    {'dacts': sys_response}
                )
                print('SYSTEM > %s ' % sys_utterance)

                if self.USE_SPEECH:
                    try:
                        tts = gTTS(sys_utterance)
                        tts.save('sys_output.mp3')
                        os.system('afplay sys_output.mp3')

                    except Exception as e:
                        print(
                            'WARNING: gTTS encountered an error: {0}. '
                            'Falling back to sys TTS.'.format(e)
                        )
                        os.system('say ' + sys_utterance)
            else:
                print(
                    'SYSTEM > %s ' %
                    '; '.join([str(sr) for sr in sys_response])
                )

            if self.USE_USR_SIMULATOR:
                usim_input = sys_response

                if self.USER_SIMULATOR_NLU and self.USE_NLG:
                    usim_input = self.user_simulator.nlu.process_input(
                        sys_utterance
                    )

                self.user_simulator.receive_input(usim_input)
                rew, success, task_success = self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
            else:
                rew, success, task_success = 0, None, None

            self.recorder.record(
                deepcopy(self.dialogue_manager.get_state()),
                self.dialogue_manager.get_state(),
                sys_response,
                rew,
                success,
                task_success,
                output_utterance=sys_utterance
            )

            self.dialogue_turn += 1

        self.prev_state = None

        # Re-initialize these for good measure
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.continue_dialogue()

    def continue_dialogue(self):
        """
        Perform next dialogue turn.

        :return: nothing
        """

        usr_utterance = ''
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            usr_input = self.user_simulator.respond()

            # TODO: THIS FIRST IF WILL BE HANDLED BY ConversationalAgentGeneric
            #  -- SHOULD NOT LIVE HERE
            if isinstance(self.user_simulator, DTLUserSimulator):
                print('USER (nlg) > %s \n' % usr_input)
                usr_input = self.nlu.process_input(
                    usr_input,
                    self.dialogue_manager.get_state()
                )

            elif self.USER_SIMULATOR_NLG:
                print('USER > %s \n' % usr_input)

                if self.nlu:
                    usr_input = self.nlu.process_input(usr_input)

                # Otherwise it will just print the user's nlg output but use
                # the simulator's output dialogue acts to proceed.

            else:
                print('USER (DACT) > %s \n' % '; '.join(
                    [str(ui) for ui in usr_input]))

        else:
            if self.USE_SPEECH:
                # Listen for input from the microphone
                with speech_rec.Microphone() as source:
                    print('(listening...)')
                    audio = self.asr.listen(source, phrase_time_limit=3)

                try:
                    # This uses the default key
                    usr_utterance = self.asr.recognize_google(audio)
                    print("Google ASR: " + usr_utterance)

                except speech_rec.UnknownValueError:
                    print("Google ASR did not understand you")

                except speech_rec.RequestError as e:
                    print("Google ASR request error: {0}".format(e))

            else:
                usr_utterance = input('USER > ')

            # Process the user's utterance
            if self.nlu:
                usr_input = self.nlu.process_input(
                    usr_utterance,
                    self.dialogue_manager.get_state()
                )
            else:
                raise EnvironmentError(
                    'agent: No nlu defined for '
                    'text-based interaction!'
                )

        self.dialogue_manager.receive_input(usr_input)

        # Keep track of prev_state, for the DialogueEpisodeRecorder
        # Store here because this is the state that the dialogue manager
        # will use to make a decision.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        if self.dialogue_turn < self.MAX_TURNS:
            sys_response = self.dialogue_manager.generate_output()

        else:
            # Force dialogue stop
            sys_response = [DialogueAct('bye', [])]

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output({'dacts': sys_response})
            print('SYSTEM > %s ' % sys_utterance)

            if self.USE_SPEECH:
                try:
                    tts = gTTS(text=sys_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')

                except Exception as e:
                    print(
                        'WARNING: gTTS encountered an error: {0}. '
                        'Falling back to sys TTS.'.format(e)
                    )
                    os.system('say ' + sys_utterance)
        else:
            print('SYSTEM > %s ' % '; '.join([str(sr) for sr in sys_response]))

        if self.USE_USR_SIMULATOR:
            usim_input = sys_response

            if self.USER_SIMULATOR_NLU and self.USE_NLG:
                usim_input = \
                    self.user_simulator.nlu.process_input(sys_utterance)

            self.user_simulator.receive_input(usim_input)
            rew, success, task_success = \
                self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
        else:
            rew, success, task_success = 0, None, None

        if self.prev_state:
            self.recorder.record(
                self.prev_state,
                self.curr_state,
                self.prev_action,
                self.prev_reward,
                self.prev_success,
                input_utterance=self.prev_usr_utterance,
                output_utterance=self.prev_sys_utterance,
                task_success=self.prev_task_success
            )

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_action = deepcopy(sys_response)
        self.prev_usr_utterance = deepcopy(usr_utterance)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

    def end_dialogue(self):
        """
        Perform final dialogue turn. Train and save models if applicable.

        :return: nothing
        """

        # Record final state
        self.recorder.record(
            self.curr_state,
            self.curr_state,
            self.prev_action,
            self.prev_reward,
            self.prev_success,
            input_utterance=self.prev_usr_utterance,
            output_utterance=self.prev_sys_utterance,
            task_success=self.prev_task_success,
            force_terminate=True
        )

        self.dialogue_episode += 1

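        # Train periodically: every train_interval dialogues, run train_epochs
        # passes, each over a fresh random minibatch of recorded dialogues.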
        if self.IS_TRAINING:
            if self.dialogue_episode % self.train_interval == 0 and \
                    len(self.recorder.dialogues) >= self.minibatch_length:
                for epoch in range(self.train_epochs):
                    print('Training epoch {0} of {1}'.format(
                        (epoch+1),
                        self.train_epochs)
                    )

                    # Sample minibatch
                    minibatch = random.sample(
                        self.recorder.dialogues,
                        self.minibatch_length
                    )

                    if self.nlu and self.nlu.training:
                        self.nlu.train(minibatch)

                    if self.dialogue_manager.is_training():
                        self.dialogue_manager.train(minibatch)

                    if self.nlg and self.nlg.training:
                        self.nlg.train(minibatch)

        # Keep track of dialogue statistics
        self.cumulative_rewards += \
            self.recorder.dialogues[-1][-1]['cumulative_reward']
        print('CUMULATIVE REWARD: {0}'.
              format(self.recorder.dialogues[-1][-1]['cumulative_reward']))

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % self.SAVE_INTERVAL == 0:
            if self.nlu:
                self.nlu.save()

            if self.dialogue_manager:
                self.dialogue_manager.save()

            if self.nlg:
                self.nlg.save()

        # Count successful dialogues
        if self.recorder.dialogues[-1][-1]['success']:
            print('SUCCESS (Subjective)!')
            self.num_successful_dialogues += \
                int(self.recorder.dialogues[-1][-1]['success'])

        else:
            print('FAILURE (Subjective).')

        if self.recorder.dialogues[-1][-1]['task_success']:
            self.num_task_success += \
                int(self.recorder.dialogues[-1][-1]['task_success'])

        print('OBJECTIVE TASK SUCCESS: {0}'.
              format(self.recorder.dialogues[-1][-1]['task_success']))

    def terminated(self):
        """
        Check if this agent is at a terminal state.

        :return: True or False
        """

        return self.dialogue_manager.at_terminal_state() or \
            self.dialogue_turn > self.MAX_TURNS
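
# SupervisedPolicy (below) is a dialogue policy trained by supervised
# learning: while is_training is True, next_action() defers to an expert
# (HandcraftedPolicy for the system role, AgendaBasedUS for the user role),
# and train() fits a small feed-forward network to imitate the recorded
# expert actions.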
class SupervisedPolicy(dialogue_policy.DialoguePolicy):
    def __init__(self, args):
        """
        Initialize parameters and internal structures

        :param args: dictionary containing the policy's arguments
        """
        super(SupervisedPolicy, self).__init__()

        self.ontology = None
        if 'ontology' in args:
            ontology = args['ontology']

            if isinstance(ontology, Ontology):
                self.ontology = ontology
            else:
                raise ValueError('SupervisedPolicy: Unacceptable '
                                 'ontology type %s ' % ontology)
        else:
            raise ValueError('SupervisedPolicy: No ontology provided')

        self.database = None
        if 'database' in args:
            database = args['database']

            if isinstance(database, DataBase):
                self.database = database
            else:
                raise ValueError('SupervisedPolicy: Unacceptable '
                                 'database type %s ' % database)
        else:
            raise ValueError('SupervisedPolicy: No database provided')

        self.agent_id = args['agent_id'] if 'agent_id' in args else 0
        self.agent_role = args['agent_role'] \
            if 'agent_role' in args else 'system'
        domain = args['domain'] if 'domain' in args else None

        # True for greedy, False for stochastic
        self.IS_GREEDY_POLICY = False

        self.policy_path = None

        self.policy_net = None
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)
        self.sess = None

        # The system and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'] +
                     ['this', 'signature'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

            else:
                print('Warning! Unknown domain. Using the default '
                      'slot-filling dialogue state.')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))
            print('Supervised dialogue policy automatically determined number '
                  'of state features: {0}'.format(self.NStateFeatures))

        if domain == 'CamRest':
            self.NActions = len(self.dstc2_acts) + len(self.requestable_slots)

            if self.agent_role == 'system':
                self.NActions += len(self.system_requestable_slots)
            else:
                self.NActions += len(self.requestable_slots)
        else:
            self.NActions = 5

        self.policy_alpha = 0.05

        self.tf_saver = None

    def initialize(self, args):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if self.agent_role == 'system':
            # Put your system expert dialogue policy here
            self.warmup_policy = HandcraftedPolicy({'ontology': self.ontology})

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert dialogue policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        if 'is_training' in args:
            self.is_training = bool(args['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in args:
                    self.warmup_simulator.initialize({'goal': args['goal']})
                else:
                    print('WARNING! No goal provided for Supervised policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in args:
            self.policy_path = args['policy_path']

        if 'learning_rate' in args:
            self.policy_alpha = args['learning_rate']

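        # NOTE: this is TensorFlow 1.x graph-mode code (InteractiveSession,
        # placeholders, tf.train.Saver). Under TensorFlow 2.x the same calls
        # are available through tf.compat.v1, with eager execution disabled.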
        if self.sess is None:
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()
            self.sess.run(tf.global_variables_initializer())

            self.tf_saver = \
                tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.tf_scope))

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for Supervised policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the dialogue policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training:
            # This is a Supervised dialogue policy, so no exploration here.

            if self.agent_role == 'system':
                return self.warmup_policy.next_action(state)
            else:
                self.warmup_simulator.receive_input(state.user_acts,
                                                    state.user_goal)
                return self.warmup_simulator.generate_output()

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss = \
            self.policy_net

        obs_vector = np.expand_dims(self.encode_state(state), axis=0)

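        # Forward pass: action probabilities for the encoded dialogue state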
        probs = self.sess.run(pl_calculated, feed_dict={pl_state: obs_vector})

        if self.IS_GREEDY_POLICY:
            # Greedy policy: Return action with maximum value from the given
            # state
            sys_acts = \
                self.decode_action(
                    np.argmax(probs), self.agent_role == 'system')

        else:
            # Stochastic dialogue policy: sample an action wrt the action
            # probabilities
            if any(np.isnan(probs[0])):
                print('WARNING! Supervised dialogue policy: NAN detected in '
                      'action probabilities! Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

            # Make sure weights are positive
            min_p = min(probs[0])

            if min_p < 0:
                # Keep this a numpy array so the normalisation below works
                positive_weights = probs[0] + abs(min_p)
            else:
                positive_weights = probs[0]

            # Normalize weights
            positive_weights /= sum(positive_weights)

            sys_acts = \
                self.decode_action(
                    random.choices(
                        [a for a in range(self.NActions)],
                        weights=positive_weights)[0],
                    self.agent_role == 'system')

        return sys_acts
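
    # Note on the sampling branch above: the network's softmax outputs are
    # already non-negative and sum (approximately) to one, so the
    # shift-to-positive and renormalisation steps act as a defensive guard
    # against numerical noise rather than a mathematical necessity.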

    def feed_forward_net_init(self):
        """
        Initialize the feed forward network.

        :return: the policy output tensor, the state and target placeholders,
                 and the optimizer and loss tensors
        """
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        with tf.variable_scope(self.tf_scope):
            state = tf.placeholder("float", [None, self.NStateFeatures])
            newvals = tf.placeholder("float", [None, self.NActions])

            w1 = \
                tf.get_variable("w1",
                                [self.NStateFeatures, self.NStateFeatures])
            b1 = tf.get_variable("b1", [self.NStateFeatures])
            h1 = tf.nn.sigmoid(tf.matmul(state, w1) + b1)

            w2 = \
                tf.get_variable("w2",
                                [self.NStateFeatures, self.NStateFeatures])
            b2 = tf.get_variable("b2", [self.NStateFeatures])
            h2 = tf.nn.sigmoid(tf.matmul(h1, w2) + b2)

            w3 = tf.get_variable("w3", [self.NStateFeatures, self.NActions])
            b3 = tf.get_variable("b3", [self.NActions])

            calculated = tf.nn.softmax(tf.matmul(h2, w3) + b3)

            diffs = calculated - newvals
            loss = tf.nn.l2_loss(diffs)
            optimizer = \
                tf.train.AdamOptimizer(self.policy_alpha).minimize(loss)

            return calculated, state, newvals, optimizer, loss
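
    # Shape summary of the network built above (a sketch, with F =
    # NStateFeatures and A = NActions):
    #
    #   state [None, F] -> h1 = sigmoid(state @ w1 + b1)  [None, F]
    #                   -> h2 = sigmoid(h1 @ w2 + b2)     [None, F]
    #                   -> softmax(h2 @ w3 + b3)          [None, A]
    #
    # The loss is the L2 distance between the softmax output and the one-hot
    # action targets fed through `newvals`, minimised with Adam.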

    def train(self, dialogues):
        """
        Train the neural net dialogue policy model

        :param dialogues: dialogue experience
        :return: nothing
        """

        # If called by accident
        if not self.is_training:
            return

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss =\
            self.policy_net

        states = []
        actions = []

        for dialogue in dialogues:
            for index, turn in enumerate(dialogue):
                act_enc = \
                    self.encode_action(turn['action'],
                                       self.agent_role == 'system')
                if act_enc > -1:
                    states.append(self.encode_state(turn['state']))
                    action = np.zeros(self.NActions)
                    action[act_enc] = 1
                    actions.append(action)

        # Train dialogue policy
        self.sess.run(pl_optimizer,
                      feed_dict={
                          pl_state: states,
                          pl_newvals: actions
                      })
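
    # Sketch of the experience format consumed above: a list of dialogues,
    # each a list of turn dicts with at least 'state' and 'action' keys,
    # e.g. (hypothetical values):
    #
    #   dialogues = [[{'state': s0, 'action': [DialogueAct('hello', [])]},
    #                 {'state': s1, 'action': [DialogueAct('request', ...)]}]]
    #
    # Turns whose action cannot be encoded (encode_action returns -1) are
    # silently dropped from the training batch.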

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: list - the encoded state vector
        """

        temp = [int(state.is_terminal_state),
                int(state.system_made_offer)]

        # If the agent plays the role of the user it needs access to its own
        # goal
        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them actually have been.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)
            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # Whether each slot has been filled
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        return temp
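
    # Layout of the returned vector: [terminal, offer_made], followed for
    # the user role by goal-constraint flags, satisfied-constraint flags,
    # request flags and satisfied-request flags, or for the system role by
    # one flag per filled slot plus a one-hot of the requested slot.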

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        the agent's role, as the agent may be encoding another agent's action
        (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        if not actions:
            print('WARNING: Supervised dialogue policy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        slot = None
        if action.params and action.params[0].slot:
            slot = action.params[0].slot

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.system_requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           self.system_requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           len(self.system_requestable_slots) + \
                           self.requestable_slots.index(slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           self.requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           len(self.requestable_slots) + \
                           self.requestable_slots.index(slot)

        # Default fall-back action
        print('Supervised ({0}) policy action encoder warning: Selecting '
              'default action (unable to encode: {1})!'.format(
                  self.agent_role, action))
        return -1
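
    # Worked example (a sketch, assuming the CamRest act lists above and
    # system_requestable_slots == ['area', 'food', 'pricerange']):
    # encoding [DialogueAct('request', [DialogueActItem('food', ...)])] as
    # a system action yields len(dstc2_acts_sys) + 1 = 13 + 1 = 14.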

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        the agent's role, as the agent may be decoding another agent's action
        (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) +\
                    len(self.requestable_slots):
                index = action_enc - \
                        len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'models/policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        # If the directory does not exist, create it
        if not os.path.exists(os.path.dirname(pol_path)):
            os.makedirs(os.path.dirname(pol_path), exist_ok=True)

        if self.sess is not None and self.is_training:
            save_path = self.tf_saver.save(self.sess, pol_path)
            print('Supervised policy model saved at: %s' % save_path)

    def load(self, path):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return:
        """

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'models/policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        if os.path.isfile(pol_path + '.meta'):
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()

            self.tf_saver = \
                tf.train.Saver(
                    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope=self.tf_scope))

            self.tf_saver.restore(self.sess, pol_path)

            print('Supervised policy model loaded from {0}.'.format(pol_path))

        else:
            print('WARNING! Supervised policy cannot load policy '
                  'model from {0}!'.format(pol_path))
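
A minimal driver sketch for the supervised policy above. The class name
SupervisedPolicy and the pre-built `ontology`, `database` and `experience`
objects are assumptions for illustration; they are not shown in this excerpt.

policy = SupervisedPolicy({'ontology': ontology,
                           'database': database,
                           'agent_role': 'system'})
policy.initialize({'is_training': True,
                   'policy_path': 'models/policies/supervised_policy'})
policy.train(experience)  # experience: a list of dialogues, as in train()
policy.save()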
Example #5
class ReinforcePolicy(dialogue_policy.DialoguePolicy):
    def __init__(self, args):
        """
        Initialize parameters and internal structures

        :param args: the policy's arguments
        """

        super(ReinforcePolicy, self).__init__()

        self.ontology = None
        if 'ontology' in args:
            ontology = args['ontology']

            if isinstance(ontology, Ontology):
                self.ontology = ontology
            else:
                raise ValueError('ReinforcePolicy: Unacceptable '
                                 'ontology type %s ' % ontology)
        else:
            raise ValueError('ReinforcePolicy: No ontology provided')

        self.database = None
        if 'database' in args:
            database = args['database']

            if isinstance(database, DataBase):
                self.database = database
            else:
                raise ValueError('ReinforcePolicy: Unacceptable '
                                 'database type %s ' % database)
        else:
            raise ValueError('ReinforcePolicy: No database provided')

        self.agent_id = args['agent_id'] if 'agent_id' in args else 0
        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        domain = args['domain'] if 'domain' in args else None
        self.alpha = args['alpha'] if 'alpha' in args else 0.2
        self.gamma = args['gamma'] if 'gamma' in args else 0.95
        self.epsilon = args['epsilon'] if 'epsilon' in args else 0.95
        self.alpha_decay_rate = \
            args['alpha_decay'] if 'alpha_decay' in args else 0.995
        self.exploration_decay_rate = \
            args['epsilon_decay'] if 'epsilon_decay' in args else 0.9995

        self.IS_GREEDY = False

        self.policy_path = None

        self.weights = None
        self.sess = None

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # DSTC2 act lists; populated below only for the CamRest domain
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy({'ontology': self.ontology})

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request, which are
                    # modelled together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request, which are
                    # modelled together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                print('Warning! Domain has not been defined. Using '
                      'slot-filling dialogue state.')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            print('Reinforce policy {0} automatically determined '
                  'number of state features: {1}'.format(
                      self.agent_role, self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        print('Reinforce {0} policy Number of Actions: {1}'.format(
            self.agent_role, self.NActions))

    def initialize(self, args):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'is_training' in args:
            self.is_training = bool(args['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in args:
                    self.warmup_simulator.initialize({'goal': args['goal']})
                else:
                    print('WARNING! No goal provided for Reinforce policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in args:
            self.policy_path = args['policy_path']

        if 'learning_rate' in args:
            self.alpha = args['learning_rate']

        if 'learning_decay_rate' in args:
            self.alpha_decay_rate = args['learning_decay_rate']

        if 'discount_factor' in args:
            self.gamma = args['discount_factor']

        if 'exploration_rate' in args:
            self.epsilon = args['exploration_rate']

        if 'exploration_decay_rate' in args:
            self.exploration_decay_rate = args['exploration_decay_rate']

        if self.weights is None:
            self.weights = np.random.rand(self.NStateFeatures, self.NActions)

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for Reinforce '
                      'policy user simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training and random.random() < self.epsilon:
            if random.random() < 0.5:
                print('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                print('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == "system")

        # Probabilistic policy: sample an action wrt the policy probabilities
        probs = self.calculate_policy(self.encode_state(state))

        if any(np.isnan(probs)):
            print('WARNING! NAN detected in action probabilities! Selecting '
                  'random action.')
            return self.decode_action(random.choice(range(0, self.NActions)),
                                      self.agent_role == "system")

        if self.IS_GREEDY:
            # Get greedy action
            max_pi = max(probs)
            maxima = [i for i, j in enumerate(probs) if j == max_pi]

            # Break ties randomly
            if maxima:
                sys_acts = \
                    self.decode_action(
                        random.choice(maxima), self.agent_role == 'system')
            else:
                print(f'--- {self.agent_role}: Warning! No maximum value '
                      f'identified for policy. Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')
        else:
            # Pick from the top 3 actions
            top_3 = np.argsort(-probs)[0:3]
            sys_acts = \
                self.decode_action(
                    random.choices(
                        top_3, probs[top_3])[0], self.agent_role == 'system')

        return sys_acts

    @staticmethod
    def softmax(x):
        """
        Calculates the softmax of x

        :param x: a vector of scores
        :return: the softmax distribution over x
        """
        e_x = np.exp(x - np.max(x))
        out = e_x / e_x.sum()
        return out

    @staticmethod
    def softmax_gradient(x):
        """
        Calculates the gradient of the softmax

        :param x: a probability vector (e.g. a softmax output)
        :return: the Jacobian matrix of the softmax
        """
        x = np.asarray(x)
        x_reshaped = x.reshape(-1, 1)
        return np.diagflat(x_reshaped) - np.dot(x_reshaped, x_reshaped.T)
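
    # Quick numeric check (values rounded): softmax([1.0, 2.0]) is
    # approximately [0.269, 0.731], and softmax_gradient of that
    # distribution is diag(p) - p p^T:
    #
    #   [[ 0.197, -0.197],
    #    [-0.197,  0.197]]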

    def calculate_policy(self, state):
        """
        Calculates the probabilities for each action from the given state

        :param state: the current dialogue state
        :return: probabilities of actions
        """
        dot_prod = np.dot(state, self.weights)
        exp_dot_prod = np.exp(dot_prod)
        return exp_dot_prod / np.sum(exp_dot_prod)
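
    # Unlike softmax() above, no max-subtraction is applied here before
    # exponentiating; this stays well-behaved mainly because the binary
    # state features and the weight clipping in train() keep the logits
    # small.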

    def train(self, dialogues):
        """
        Train the policy network

        :param dialogues: dialogue experience
        :return: nothing
        """
        # If called by accident
        if not self.is_training:
            return

        for dialogue in dialogues:
            discount = self.gamma

            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            rewards = [t['reward'] for t in dialogue]
            norm_rewards = \
                (rewards - np.mean(rewards)) / (np.std(rewards) + 0.000001)

            for (t, turn) in enumerate(dialogue):
                act_enc = self.encode_action(turn['action'],
                                             self.agent_role == 'system')
                if act_enc < 0:
                    continue

                state_enc = self.encode_state(turn['state'])

                if len(state_enc) != self.NStateFeatures:
                    raise ValueError(f'Reinforce dialogue policy '
                                     f'{self.agent_role} mismatch in state '
                                     f'dimensions: State Features: '
                                     f'{self.NStateFeatures} != State '
                                     f'Encoding Length: {len(state_enc)}')

                # Calculate the gradients

                # Call policy again to retrieve the probability of the
                # action taken
                probabilities = self.calculate_policy(state_enc)

                softmax_deriv = self.softmax_gradient(probabilities)[act_enc]
                log_policy_grad = softmax_deriv / probabilities[act_enc]
                gradient = \
                    np.asarray(
                        state_enc)[None, :].transpose().dot(
                        log_policy_grad[None, :])
                gradient = np.clip(gradient, -1.0, 1.0)

                # Train policy
                self.weights += \
                    self.alpha * gradient * norm_rewards[t] * discount
                self.weights = np.clip(self.weights, -1, 1)

                discount *= self.gamma

        if self.alpha > 0.01:
            self.alpha *= self.alpha_decay_rate

        if self.epsilon > 0.5:
            self.epsilon *= self.exploration_decay_rate

        print(f'REINFORCE train, alpha: {self.alpha}, epsilon: {self.epsilon}')
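
    # The update above is the standard REINFORCE rule, with the per-dialogue
    # normalised reward r_t standing in for the return and no learned
    # baseline:
    #
    #   w <- w + alpha * gamma^(t+1) * r_t * grad_w log pi_w(a_t | s_t)
    #
    # (the discount starts at gamma rather than 1). For this linear-softmax
    # policy, grad_w log pi(a|s) is s^T (d softmax / d logits)[a] / pi(a|s),
    # which is what softmax_gradient() computes.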

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: list - the encoded state vector
        """

        temp = [int(state.is_terminal_state), int(state.system_made_offer)]

        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them actually have been.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:

                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)

            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # Whether each slot has been filled
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        return temp

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        the agent's role, as the agent may be encoding another agent's action
        (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: Reinforce dialogue policy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)

        # Default fall-back action
        print('Reinforce ({0}) policy action encoder warning: Selecting '
              'default action (unable to encode: {1})!'.format(
                  self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        the agent's role, as the agent may be decoding another agent's action
        (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = action_enc - len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        print('Reinforce ({0}) policy action decoder warning: Selecting '
              'default action (index: {1})!'.format(
                  self.agent_role, action_enc))
        return [DialogueAct('bye', [])]

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'models/policies/reinforce.pkl'
            print('No policy file name provided. Using default: {0}'.format(
                path))

        # If the directory does not exist, create it
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        obj = {
            'weights': self.weights,
            'alpha': self.alpha,
            'alpha_decay_rate': self.alpha_decay_rate,
            'epsilon': self.epsilon,
            'exploration_decay_rate': self.exploration_decay_rate
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path=None):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return:
        """

        if not path:
            print('No dialogue policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'weights' in obj:
                        self.weights = obj['weights']

                    if 'alpha' in obj:
                        self.alpha = obj['alpha']

                    if 'alpha_decay_rate' in obj:
                        self.alpha_decay_rate = obj['alpha_decay_rate']

                    if 'epsilon' in obj:
                        self.epsilon = obj['epsilon']

                    if 'exploration_decay_rate' in obj:
                        self.exploration_decay_rate = \
                            obj['exploration_decay_rate']

                    print('Reinforce policy loaded from {0}.'.format(path))

            else:
                print('Warning! Reinforce policy file %s not found' % path)
        else:
            print('Warning! Unacceptable value for Reinforce policy '
                  'file name: %s ' % path)
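
A minimal training-loop sketch for ReinforcePolicy. The `ontology`,
`database` and `collect_dialogues` names are hypothetical placeholders for
objects built elsewhere.

policy = ReinforcePolicy({'ontology': ontology,
                          'database': database,
                          'agent_role': 'system',
                          'domain': 'CamRest'})
policy.initialize({'is_training': True})

for _ in range(100):
    dialogues = collect_dialogues(policy)  # hypothetical experience source
    policy.train(dialogues)                # also decays alpha and epsilon

policy.save('models/policies/reinforce.pkl')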
Example #8
class WoLFPHCPolicy(dialogue_policy.DialoguePolicy):
    def __init__(self, args):
        """
        Initialize parameters and internal structures

        :param args: dictionary containing the dialogue_policy's settings
        """

        super(WoLFPHCPolicy, self).__init__()

        self.ontology = None
        if 'ontology' in args:
            ontology = args['ontology']

            if isinstance(ontology, Ontology):
                self.ontology = ontology
            elif isinstance(ontology, str):
                self.ontology = Ontology(ontology)
            else:
                raise ValueError('WoLFPHCPolicy: Unacceptable '
                                 'ontology type %s ' % ontology)
        else:
            raise ValueError('WoLFPHCPolicy: No ontology provided')

        self.database = None
        if 'database' in args:
            database = args['database']

            if isinstance(database, DataBase):
                self.database = database
            elif isinstance(database, str):
                self.database = DataBase(database)
            else:
                raise ValueError('WoLFPHCPolicy: Unacceptable '
                                 'database type %s ' % database)
        else:
            raise ValueError('WoLFPHCPolicy: No database provided')

        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        self.alpha = args['alpha'] if 'alpha' in args else 0.2
        self.gamma = args['gamma'] if 'gamma' in args else 0.95
        self.epsilon = args['epsilon'] if 'epsilon' in args else 0.95
        self.alpha_decay_rate = \
            args['alpha_decay'] if 'alpha_decay' in args else 0.995
        self.exploration_decay_rate = \
            args['epsilon_decay'] if 'epsilon_decay' in args else 0.9995

        self.IS_GREEDY_POLICY = False

        # TODO: Put these as arguments in the config
        self.d_win = 0.0025
        self.d_lose = 0.01

        self.is_training = False

        self.Q = {}
        self.pi = {}
        self.mean_pi = {}
        self.state_counter = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert dialogue_policy here
            self.warmup_policy = \
                slot_filling_policy.HandcraftedPolicy({
                    'ontology': self.ontology})

        elif self.agent_role == 'user':
            usim_args = dict(
                zip(['ontology', 'database'], [self.ontology, self.database]))
            # Put your user expert dialogue_policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # DSTC2 act lists (sub-case for CamRest); populated just below
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Does not include inform and request, which are modelled together
        # with their arguments
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        # Does not include inform and request, which are modelled together
        # with their arguments
        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if self.dstc2_acts_sys:
            if self.agent_role == 'system':
                # self.NActions = 5
                # self.NOtherActions = 4
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

            elif self.agent_role == 'user':
                # self.NActions = 4
                # self.NOtherActions = 5
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) +\
                    len(self.system_requestable_slots)

                self.NOtherActions = len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])

            elif self.agent_role == 'user':
                self.NActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

        self.statistics = {'supervised_turns': 0, 'total_turns': 0}

    def initialize(self, args):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'train' in args:
            self.is_training = bool(args['train'])

            if 'learning_rate' in args:
                self.alpha = float(args['learning_rate'])

            if 'learning_decay_rate' in args:
                self.alpha_decay_rate = float(args['learning_decay_rate'])

            if 'exploration_rate' in args:
                self.epsilon = float(args['exploration_rate'])

            if 'exploration_decay_rate' in args:
                self.exploration_decay_rate = \
                    float(args['exploration_decay_rate'])

            if 'gamma' in args:
                self.gamma = float(args['gamma'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in args:
                    self.warmup_simulator.initialize({'goal': args['goal']})
                else:
                    print('WARNING! No goal provided for WoLF-PHC policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for WoLF PHC policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the dialogue_policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)
        self.statistics['total_turns'] += 1

        if state_enc not in self.pi or \
                (self.is_training and random.random() < self.epsilon):
            if not self.is_training:
                if not self.pi:
                    print(f'\nWARNING! WoLF-PHC pi is empty '
                          f'({self.agent_role}). Did you load the correct '
                          f'file?\n')
                else:
                    print(f'\nWARNING! WoLF-PHC state not found in policy '
                          f'pi ({self.agent_role}).\n')

            if random.random() < 0.35:
                print('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))
                self.statistics['supervised_turns'] += 1

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()
            else:
                print('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        if self.IS_GREEDY_POLICY:
            # Get greedy action
            max_pi = max(self.pi[state_enc][:-1])  # Do not consider 'UNK'
            maxima = \
                [i for i, j in enumerate(self.pi[state_enc][:-1])
                 if j == max_pi]

            # Break ties randomly
            if maxima:
                sys_acts = \
                    self.decode_action(random.choice(maxima),
                                       self.agent_role == 'system')
            else:
                print('--- {0}: Warning! No maximum value identified for '
                      'dialogue policy. Selecting random action.'.format(
                          self.agent_role))

                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')
        else:
            # Sample next action
            sys_acts = \
                self.decode_action(
                    random.choices(range(len(self.pi[state_enc])),
                                   self.pi[state_enc])[0],
                    self.agent_role == 'system')

        return sys_acts

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state encoding
        """

        temp = [int(state.is_terminal_state)]

        temp.append(1 if state.system_made_offer else 0)

        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them actually have been.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints and \
                                state.user_goal.constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)

            else:
                temp += \
                    [0] * 2*(len(self.informable_slots)-1 +
                             len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        # Encode state
        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t
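
        # Illustration: the loop above packs the binary features into a
        # single integer, e.g. temp = [1, 0, 1] gives
        # ((0 << 1 | 1) << 1 | 0) << 1 | 1 = 0b101 = 5, so every distinct
        # feature vector maps to a distinct index.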

        return state_enc

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another
        agent's action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: WoLF-PHC dialogue_policy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                if action.params[0].slot not in self.system_requestable_slots:
                    return -1

                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                if action.params[0].slot not in self.requestable_slots:
                    return -1

                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                if action.params[0].slot not in self.requestable_slots:
                    return -1

                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                if action.params[0].slot not in self.system_requestable_slots:
                    return -1

                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        if (self.agent_role == 'system') == system:
            print('WoLF-PHC ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode: {1})!'.format(
                      self.agent_role, action))
        else:
            print('WoLF-PHC ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode other agent action: {1})!'.
                  format(self.agent_role, action))

        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another
        agent's action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        print('WoLF-PHC dialogue_policy ({0}) policy action decoder warning: '
              'Selecting repeat() action (index: {1})!'.format(
                  self.agent_role, action_enc))
        return [DialogueAct('repeat', [])]
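
        # For example, assuming the same DSTC2 action set as MinimaxQPolicy
        # below, action_enc == 0 decodes to [DialogueAct('offer', [])] for
        # the system role, and the first index past the plain intents
        # decodes to a request() over the first system-requestable slot.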

    def train(self, dialogues):
        """
        Train the model using WoLF-PHC.

        :param dialogues: a list of dialogues, each of which is a list of
                          dialogue turns (state, action, reward triplets).
        :return:
        """

        if not self.is_training:
            return

        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])

                role = self.agent_role
                if 'role' in turn:
                    role = turn['role']

                action_enc = \
                    self.encode_action(
                        turn['action'],
                        role == 'system')

                # Skip unrecognised actions
                if action_enc < 0 or turn['action'][0].intent == 'bye':
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = [0] * self.NActions

                if new_state_enc not in self.Q:
                    self.Q[new_state_enc] = [0] * self.NActions

                if state_enc not in self.pi:
                    self.pi[state_enc] = \
                        [float(1/self.NActions)] * self.NActions

                if state_enc not in self.mean_pi:
                    self.mean_pi[state_enc] = \
                        [float(1/self.NActions)] * self.NActions

                if state_enc not in self.state_counter:
                    self.state_counter[state_enc] = 1
                else:
                    self.state_counter[state_enc] += 1

                # Update Q
                self.Q[state_enc][action_enc] = \
                    ((1 - self.alpha) * self.Q[state_enc][action_enc]) + \
                    self.alpha * (
                            turn['reward'] +
                            (self.gamma * np.max(self.Q[new_state_enc])))
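
                # E.g. (illustrative numbers): with alpha = 0.25,
                # gamma = 0.95, reward = -1, Q(s, a) = 0.4 and
                # max Q(s') = 0.8, this update gives
                # 0.75 * 0.4 + 0.25 * (-1 + 0.95 * 0.8) = 0.24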

                # Update mean dialogue_policy estimate
                for a in range(self.NActions):
                    self.mean_pi[state_enc][a] = \
                        self.mean_pi[state_enc][a] + \
                        ((1.0 / self.state_counter[state_enc]) *
                         (self.pi[state_enc][a] - self.mean_pi[state_enc][a]))

                # Determine delta
                sum_policy = 0.0
                sum_mean_policy = 0.0

                for a in range(self.NActions):
                    sum_policy = sum_policy + (self.pi[state_enc][a] *
                                               self.Q[state_enc][a])
                    sum_mean_policy = \
                        sum_mean_policy + \
                        (self.mean_pi[state_enc][a] * self.Q[state_enc][a])

                if sum_policy > sum_mean_policy:
                    delta = self.d_win
                else:
                    delta = self.d_lose

                # Update dialogue_policy estimate
                max_q_idx = np.argmax(self.Q[state_enc])

                d_plus = delta
                d_minus = ((-1.0) * d_plus) / (self.NActions - 1.0)

                for a in range(self.NActions):
                    if a == max_q_idx:
                        self.pi[state_enc][a] = \
                            min(1.0, self.pi[state_enc][a] + d_plus)
                    else:
                        self.pi[state_enc][a] = \
                            max(0.0, self.pi[state_enc][a] + d_minus)

                # Constrain pi to a legal probability distribution
                sum_pi = sum(self.pi[state_enc])
                for a in range(self.NActions):
                    self.pi[state_enc][a] /= sum_pi
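
                # Worked example (illustrative numbers): with
                # pi = [0.5, 0.3, 0.2], delta = 0.06 and max_q_idx = 1,
                # d_plus = 0.06 and d_minus = -0.03 give
                # pi = [0.47, 0.36, 0.17]; the normalization above only
                # matters when the min/max clipping has distorted the sum.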

        # Decay learning rate after each episode
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay_rate

        # Decay exploration rate after each episode
        if self.epsilon > 0.25:
            self.epsilon *= self.exploration_decay_rate

        print('[alpha: {0}, epsilon: {1}]'.format(self.alpha, self.epsilon))

    def save(self, path=None):
        """
        Saves the dialogue_policy model to the path provided

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'models/policies/wolf_phc_policy.pkl'
            print('No dialogue_policy file name provided. Using default: {0}'.
                  format(path))

        # If the directory does not exist, create it
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        obj = {
            'Q': self.Q,
            'pi': self.pi,
            'mean_pi': self.mean_pi,
            'state_counter': self.state_counter,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

        if self.statistics['total_turns'] > 0:
            print(
                'DEBUG > {0} WoLF PHC dialogue_policy supervision ratio: {1}'.
                format(
                    self.agent_role,
                    float(self.statistics['supervised_turns'] /
                          self.statistics['total_turns'])))

        print(f'DEBUG > {self.agent_role} WoLF PHC policy state space '
              f'size: {len(self.pi)}')

    def load(self, path=None):
        """
        Load the dialogue_policy model from the path provided

        :param path: path to load the model from
        :return:
        """

        if not path:
            print('No dialogue_policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'pi' in obj:
                        self.pi = obj['pi']
                    if 'mean_pi' in obj:
                        self.mean_pi = obj['mean_pi']
                    if 'state_counter' in obj:
                        self.state_counter = obj['state_counter']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('WoLF-PHC dialogue_policy loaded from {0}.'.format(
                        path))

            else:
                print('Warning! WoLF-PHC dialogue_policy file %s not found' %
                      path)
        else:
            print('Warning! Unacceptable value for WoLF-PHC policy file name:'
                  ' %s ' % path)
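
# Usage sketch (illustrative, not from the source): the enclosing class
# name and constructor arguments fall outside this excerpt, so the names
# below mirror MinimaxQPolicy further down and are assumptions:
#
#     policy = WoLFPHCPolicy({'ontology': ontology, 'database': database,
#                             'agent_role': 'system'})
#     policy.initialize({'train': True, 'learning_rate': 0.25})
#     policy.train(dialogues)  # a list of dialogues, as train() expects
#     policy.save()  # defaults to models/policies/wolf_phc_policy.pkl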
Example #9
class MinimaxQPolicy(dialogue_policy.DialoguePolicy):
    def __init__(self, args):
        """
        Initialize parameters and internal structures

        :param args: the policy's arguments
        """

        super(MinimaxQPolicy, self).__init__()

        self.ontology = None
        if 'ontology' in args:
            ontology = args['ontology']

            if isinstance(ontology, Ontology):
                self.ontology = ontology
            else:
                raise ValueError('MinimaxQPolicy dialogue_policy: '
                                 'Unacceptable ontology type %s ' % ontology)
        else:
            raise ValueError('MinimaxQPolicy dialogue_policy: No ontology '
                             'provided')

        self.database = None
        if 'database' in args:
            database = args['database']

            if isinstance(database, DataBase):
                self.database = database
            else:
                raise ValueError('MinimaxQPolicy policy: Unacceptable '
                                 'database type %s ' % database)
        else:
            raise ValueError('MinimaxQPolicy policy: No database provided')

        self.agent_id = args['agent_id'] if 'agent_id' in args else 0
        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        self.alpha = args['alpha'] if 'alpha' in args else 0.2
        self.gamma = args['gamma'] if 'gamma' in args else 0.95
        self.epsilon = args['epsilon'] if 'epsilon' in args else 0.95
        self.alpha_decay = \
            args['alpha_decay'] if 'alpha_decay' in args else 0.995
        self.epsilon_decay = \
            args['epsilon_decay'] if 'epsilon_decay' in args else 0.9995

        self.is_training = False

        self.Q = {}
        self.V = {}
        self.pi = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert dialogue_policy here
            self.warmup_policy = \
                slot_filling_policy.HandcraftedPolicy({
                    'ontology': self.ontology})

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert dialogue_policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Sub-case for CamRest
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = \
                    5 + \
                    len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

                self.NOtherActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])

            elif self.agent_role == 'user':
                self.NActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

    def initialize(self, args):
        """
        Initialize internal parameters

        :return: Nothing
        """

        if 'is_training' in args:
            self.is_training = bool(args['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in args:
                    self.warmup_simulator.initialize({'goal': args['goal']})
                else:
                    print('WARNING! No goal provided for Minimax Q policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return:
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for Minimax Q policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the dialogue_policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)

        if state_enc not in self.pi or \
                (self.is_training and random.random() < self.epsilon):
            if not self.is_training:
                if not self.pi:
                    print(f'\nWARNING! Minimax Q {self.agent_role} matrix is '
                          f'empty. Did you load the correct file?\n')
                else:
                    print(f'\nWARNING! Minimax Q {self.agent_role} state not '
                          f'found in Q matrix.\n')

            if random.random() < 0.5:
                print('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                print('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        # Return best action
        max_pi = max(self.pi[state_enc])
        maxima = [i for i, j in enumerate(self.pi[state_enc]) if j == max_pi]

        # Break ties randomly
        if maxima:
            sys_acts = \
                self.decode_action(
                    random.choice(maxima), self.agent_role == 'system')
        else:
            print('--- {0}: Warning! No maximum value identified for policy. '
                  'Selecting random action.'.format(self.agent_role))
            return self.decode_action(random.choice(range(0, self.NActions)),
                                      self.agent_role == 'system')

        return sys_acts

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state encoding
        """
        temp = [int(state.is_terminal_state)]

        temp.append(1 if state.system_made_offer else 0)

        # If the agent plays the role of the user it needs access to its own
        # goal
        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them actually have been.
            if state.user_goal:
                found_unanswered_constr = False
                found_unanswered_req = False

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints and \
                                c not in state.user_goal.actual_constraints:
                            found_unanswered_constr = True
                            break

                for r in self.requestable_slots:
                    if r in state.user_goal.requests and \
                            not state.user_goal.requests[r].value:
                        found_unanswered_req = True
                        break

                temp += \
                    [int(found_unanswered_constr), int(found_unanswered_req)]
            else:
                temp += [0, 0]

        if self.agent_role == 'system':
            temp.append(int(state.is_terminal()))
            temp.append(int(state.system_made_offer))

            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t

        return state_enc

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another
        agent's action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: MinimaxQ policy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                if action.params[0].slot not in \
                        self.system_requestable_slots:
                    return -1

                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                if action.params[0].slot not in self.requestable_slots:
                    return -1

                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(
                           action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                if action.params[0].slot not in self.requestable_slots:
                    return -1

                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                if action.params[0].slot not in \
                        self.system_requestable_slots:
                    return -1

                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        if (self.agent_role == 'system') == system:
            print('MinimaxQ ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode: {1})!'.format(
                      self.agent_role, action))
        else:
            print('MinimaxQ ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode other agent action: '
                  '{1})!'.format(self.agent_role, action))

        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another
        agent's action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        print('MinimaxQ dialogue_policy ({0}) policy action decoder warning: '
              'Selecting repeat() action '
              '(index: {1})!'.format(self.agent_role, action_enc))
        return [DialogueAct('repeat', [])]

    def train(self, dialogues):
        """
        Train the model using MinimaxQ.

        :param dialogues: a list of dialogues, each of which is a list of
                          dialogue turns (state, action, reward triplets).
        :return:
        """

        if not self.is_training:
            return

        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])
                action_enc = \
                    self.encode_action(
                        turn['action'],
                        self.agent_role == 'system')
                other_action_enc = \
                    self.encode_action(
                        turn['state'].user_acts,
                        self.agent_role != 'system')

                if action_enc < 0 or other_action_enc < 0 or \
                        turn['action'][0].intent == 'bye':
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = []

                    for oa in range(self.NOtherActions):
                        self.Q[state_enc].append([])

                        for a in range(self.NActions):
                            self.Q[state_enc][oa].append(1)

                if state_enc not in self.pi:
                    # pi must be a distribution over actions; next_action()
                    # and maxmin() treat it as a vector
                    self.pi[state_enc] = \
                        [float(1 / self.NActions)] * self.NActions

                # Q[state_enc] is created above with an entry for every
                # (other_action, action) pair, so the cell addressed in the
                # update below is guaranteed to exist

                if new_state_enc not in self.V:
                    self.V[new_state_enc] = 0

                if new_state_enc not in self.pi:
                    self.pi[new_state_enc] = \
                        [float(1 / self.NActions)] * self.NActions

                delta = turn['reward'] + self.gamma * self.V[new_state_enc]

                # Only update Q values (actor) that lead to an increase in Q
                # if delta > self.Q[state_enc][other_action_enc][action_enc]:
                self.Q[state_enc][other_action_enc][action_enc] += \
                    self.alpha * delta

                # Update V (critic)
                self.V[state_enc] = self.maxmin(state_enc)

        # Decay learning rate after each episode
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate after each episode
        if self.epsilon > 0.25:
            self.epsilon *= self.epsilon_decay

        print('MiniMaxQ [alpha: {0}, epsilon: {1}]'.format(
            self.alpha, self.epsilon))

    def maxmin(self, state_enc, retry=False):
        """
        Solve the maxmin problem

        :param state_enc: the encoding to the state
        :param retry: True when retrying once after a failed LP solve
        :return:
        """

        c = np.zeros(self.NActions + 1)
        c[0] = -1
        A_ub = np.ones((self.NOtherActions, self.NActions + 1))
        A_ub[:, 1:] = -np.asarray(self.Q[state_enc])
        b_ub = np.zeros(self.NOtherActions)
        A_eq = np.ones((1, self.NActions + 1))
        A_eq[0, 0] = 0
        b_eq = [1]
        bounds = ((None, None), ) + ((0, 1), ) * self.NActions
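
        # The linear program below computes the maximin value of the state:
        #     maximize    v
        #     subject to  v - sum_a pi(a) * Q[s][o][a] <= 0  for every
        #                 opponent action o  (rows of A_ub, b_ub), with
        #                 sum_a pi(a) = 1  (A_eq, b_eq) and 0 <= pi(a) <= 1.
        # res.x[0] is then the game value v and res.x[1:] the maximin pi.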

        res = linprog(c,
                      A_ub=A_ub,
                      b_ub=b_ub,
                      A_eq=A_eq,
                      b_eq=b_eq,
                      bounds=bounds)

        if res.success:
            self.pi[state_enc] = res.x[1:]
        elif not retry:
            return self.maxmin(state_enc, retry=True)
        else:
            print("Alert : %s" % res.message)
            if state_enc in self.V:
                return self.V[state_enc]
            else:
                print('Warning, state not in V, returning 0.')
                return 0

        return res.x[0]

    def save(self, path=None):
        """
        Save the model in the path provided

        :param path: path to save the model to
        :return: nothing
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'models/policies/minimax_q_policy.pkl'
            print('No dialogue_policy file name provided. Using default: {0}'.
                  format(path))

        # If the directory does not exist, create it
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        obj = {
            'Q': self.Q,
            'V': self.V,
            'pi': self.pi,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """
        Load the model from the path provided

        :param path: path to load the model from
        :return: nothing
        """

        if not path:
            print('No dialogue_policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'V' in obj:
                        self.V = obj['V']
                    if 'pi' in obj:
                        self.pi = obj['pi']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('MinimaxQ dialogue_policy loaded from {0}.'.format(
                        path))

            else:
                print('Warning! MinimaxQ dialogue_policy file %s not found' %
                      path)
        else:
            print('Warning! Unacceptable value for MinimaxQ policy file '
                  'name: %s ' %
                  path)
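
# Usage sketch (illustrative; 'ontology', 'database' and 'dialogues' stand
# for already constructed objects that this snippet assumes exist):
#
#     policy = MinimaxQPolicy({'ontology': ontology, 'database': database,
#                              'agent_role': 'system'})
#     policy.initialize({'is_training': True})
#     policy.train(dialogues)  # a list of dialogues of turn dicts
#     policy.save()  # defaults to models/policies/minimax_q_policy.pkl
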
class QPolicy(dialogue_policy.DialoguePolicy):
    def __init__(self, args):
        """
        Initialize parameters and internal structures

        :param args: the policy's arguments
        """
        super(QPolicy, self).__init__()

        self.ontology = None
        if 'ontology' in args:
            ontology = args['ontology']

            if isinstance(ontology, Ontology):
                self.ontology = ontology
            else:
                raise ValueError('QPolicy dialogue policy: Unacceptable '
                                 'ontology type %s ' % ontology)
        else:
            raise ValueError('QPolicy dialogue policy: No ontology '
                             'provided')

        self.database = None
        if 'database' in args:
            database = args['database']

            if isinstance(database, DataBase):
                self.database = database
            else:
                raise ValueError('QPolicy dialogue policy: Unacceptable '
                                 'database type %s ' % database)
        else:
            raise ValueError('QPolicy dialogue policy: No database '
                             'provided')

        domain = args['domain'] if 'domain' in args else None
        self.agent_id = args['agent_id'] if 'agent_id' in args else 0
        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        self.alpha = args['alpha'] if 'alpha' in args else 0.2
        self.gamma = args['gamma'] if 'gamma' in args else 0.95
        self.epsilon = args['epsilon'] if 'epsilon' in args else 0.95
        self.alpha_decay = \
            args['alpha_decay'] if 'alpha_decay' in args else 0.995
        self.epsilon_decay = \
            args['epsilon_decay'] if 'epsilon_decay' in args else 0.9995

        self.is_training = False
        self.IS_GREEDY_POLICY = True

        self.Q = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert dialogue policy here
            self.warmup_policy = \
                slot_filling_policy.HandcraftedPolicy({
                    'ontology': self.ontology})

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert dialogue policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]

        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

                    self.NActions = \
                        len(self.dstc2_acts) + len(self.requestable_slots)

                    if self.agent_role == 'system':
                        self.NActions += len(self.system_requestable_slots)
                    else:
                        self.NActions += len(self.requestable_slots)

    def initialize(self, args):
        """
        Initialize internal parameters

        :return: Nothing
        """

        if 'is_training' in args:
            self.is_training = bool(args['is_training'])

        if 'agent_role' in args:
            self.agent_role = args['agent_role']

    def restart(self, args):
        """
        Nothing to do here.

        :return:
        """

        pass

    def next_action(self, state):
        """
        Consults the dialogue policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)

        if state_enc not in self.Q or (self.is_training
                                       and random.random() < self.epsilon):

            if random.random() < 0.5:
                # During exploration we may want to follow another dialogue
                # policy, e.g. an expert dialogue policy.

                print('---: Selecting warmup action.')

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)
                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                # Return a random action
                print('---: Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        if self.IS_GREEDY_POLICY:
            # Return action with maximum Q value from the given state
            sys_acts = self.decode_action(
                max(self.Q[state_enc], key=self.Q[state_enc].get),
                self.agent_role == 'system')
        else:
            sys_acts = self.decode_action(
                random.choices(range(0, self.NActions), self.Q[state_enc])[0],
                self.agent_role == 'system')

        return sys_acts

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state ID
        """

        temp = []

        temp += [int(b) for b in format(state.turn, '06b')]

        for value in state.slots_filled.values():
            # This contains the requested slot
            temp.append(1 if value else 0)

        for slot in self.ontology.ontology['requestable']:
            temp.append(1 if slot == state.requested_slot else 0)

        temp.append(int(state.is_terminal_state))

        # If the agent is a system, then this shows what the top db result is.
        # If the agent is a user, then this shows what information the
        # system has provided
        if state.item_in_focus:
            for slot in self.ontology.ontology['requestable']:
                if slot in state.item_in_focus and state.item_in_focus[slot]:
                    temp.append(1)
                else:
                    temp.append(0)
        else:
            temp += [0] * len(self.ontology.ontology['requestable'])

        if state.db_matches_ratio >= 0:
            temp += \
                [int(b) for b in
                 format(int(round(state.db_matches_ratio, 2) * 100), '07b')]
        else:
            # If the number is negative (should not happen in general) there
            # will be a minus sign
            temp += \
                [int(b) for b in
                 format(int(round(state.db_matches_ratio, 2) * 100),
                        '07b')[1:]]
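
        # Example (illustrative): db_matches_ratio = 0.37 becomes int(37),
        # and format(37, '07b') = '0100101' adds seven binary features.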

        temp.append(1 if state.system_made_offer else 0)

        if state.user_acts:
            temp += \
                [int(b) for b in
                 format(self.encode_action(state.user_acts, False), '05b')]
        else:
            temp += [0, 0, 0, 0, 0]

        if state.last_sys_acts:
            temp += \
                [int(b) for b in
                 format(self.encode_action([state.last_sys_acts[0]]), '04b')]
        else:
            temp += [0, 0, 0, 0]
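
        # Note: encode_action() returns -1 for acts it cannot encode, and
        # format(-1, '05b') / format(-1, '04b') produce a leading '-' that
        # the int(b) conversions above cannot parse; both fields therefore
        # assume a valid, non-negative action encoding.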

        # If the agent plays the role of the user it needs access to its own
        # goal
        if state.user_goal:
            for c in self.ontology.ontology['informable']:
                if c in state.user_goal.constraints and \
                        state.user_goal.constraints[c].value:
                    temp.append(1)
                else:
                    temp.append(0)

            for r in self.ontology.ontology['requestable']:
                if r in state.user_goal.requests and \
                        state.user_goal.requests[r].value:
                    temp.append(1)
                else:
                    temp.append(0)
        else:
            # Just for symmetry, for all other roles append zeros
            temp += [0] * (len(self.ontology.ontology['informable']) +
                           len(self.ontology.ontology['requestable']))

        # Encode state
        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t

        return state_enc

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another
        agent's action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: Q-Learning dialogue policy action encoding '
                  'called with empty actions list (returning -1).')
            return -1

        action = actions[0]

        slot = None
        if action.params and action.params[0].slot:
            slot = action.params[0].slot

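        # Action indices are laid out contiguously per role:
        # [dstc2 acts | request(slot) acts | inform(slot) acts]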
        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if slot:
                if action.intent == 'request' and slot in \
                        self.system_requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           self.system_requestable_slots.index(slot)

                if action.intent == 'inform' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           len(self.system_requestable_slots) + \
                           self.requestable_slots.index(slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if slot:
                if action.intent == 'request' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           self.requestable_slots.index(slot)

                if action.intent == 'inform' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           len(self.requestable_slots) + \
                           self.requestable_slots.index(slot)

        # Unable to encode action
        print('Q-Learning ({0}) dialogue policy action encoder warning: '
              'Unable to encode action (returning -1): {1}'.format(
                  self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that the role does not have
        to match the agent's role, as the agent may be decoding another
        agent's action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

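        # Invert the contiguous layout used by encode_action: check which of
        # the three index ranges (dstc2 acts, request acts, inform acts) the
        # encoding falls into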
        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Fall through: make the failure explicit instead of implicitly
        # returning None
        print('Q-Learning ({0}) dialogue policy action decoder warning: '
              'Unable to decode action (index: {1})!'.format(
                  self.agent_role, action_enc))
        return []

    def train(self, dialogues):
        """
        Train the model using Q-learning.

        :param dialogues: a list of dialogues, each of which is a list of
                          dialogue turns (state, action, reward triplets)
        :return: nothing
        """

        for dialogue in dialogues:
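            # Propagate the final reward back one turn, since the terminal
            # turn typically carries the dialogue's success signal but no
            # action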
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])
                action_enc = self.encode_action(turn['action'])

                if action_enc < 0:
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = {}

                if action_enc not in self.Q[state_enc]:
                    self.Q[state_enc][action_enc] = 0

                max_q = 0
                if new_state_enc in self.Q and self.Q[new_state_enc]:
                    max_q = max(self.Q[new_state_enc].values())

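                # Tabular Q-learning update:
                # Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))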
                self.Q[state_enc][action_enc] += \
                    self.alpha * (turn['reward'] +
                                  self.gamma * max_q -
                                  self.Q[state_enc][action_enc])

        # Decay learning rate
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate
        if self.epsilon > 0.05:
            self.epsilon *= self.epsilon_decay

        print('Q-Learning: [alpha: {0}, epsilon: {1}]'.format(
            self.alpha, self.epsilon))

    def save(self, path=None):
        """
        Save the Q learning dialogue policy model

        :param path: the path to save the model to
        :return: nothing
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'models/policies/q_policy.pkl'
            print('No dialogue policy file name provided. Using default: {0}'.
                  format(path))

        # If the directory does not exist, create it (os.path.dirname returns
        # '' for paths with no directory component, so guard against that)
        if os.path.dirname(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        obj = {
            'Q': self.Q,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path=None):
        """
        Load the Q learning dialogue policy model

        :param path: the path to load the model from
        :return: nothing
        """

        if not path:
            print('No dialogue policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('Q dialogue policy loaded from {0}.'.format(path))

            else:
                print('Warning! Q dialogue policy file {0} not found'.format(
                    path))
        else:
            print('Warning! Unacceptable value for Q dialogue policy file '
                  'name: {0}'.format(path))
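
# A minimal usage sketch (the class name QPolicy and the constructor
# arguments below are assumptions for illustration, not shown in this
# listing):
#
#     policy = QPolicy(ontology, database)
#     policy.train(dialogues)   # dialogues collected during interaction
#     policy.save()             # defaults to models/policies/q_policy.pkl
#     policy.load('models/policies/q_policy.pkl')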