Example #1
    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    self.warmup_simulator.initialize(
                        {'goal': kwargs['goal']})
                else:
                    print('WARNING! No goal provided for Supervised policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.policy_alpha = kwargs['learning_rate']

        if self.sess is None:
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()
            self.sess.run(tf.global_variables_initializer())

            self.tf_saver = \
                tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.tf_scope))
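
Read together, the kwargs handled above suggest the call pattern below. This is a minimal usage sketch, assuming `policy` is an instance of this class and `goal` is a goal object built elsewhere; the file path and values are hypothetical.

# Minimal usage sketch -- `policy`, `goal` and the path are assumptions,
# not part of the example above.
policy.initialize(
    is_training=True,                             # stored as self.is_training
    goal=goal,                                    # only used when agent_role == 'user'
    policy_path='models/supervised_policy.pkl',   # hypothetical path
    learning_rate=0.05)                           # stored as self.policy_alpha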
Example #2
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 alpha=0.25,
                 gamma=0.95,
                 epsilon=0.25,
                 alpha_decay=0.9995,
                 epsilon_decay=0.995):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay = alpha_decay
        self.epsilon_decay = epsilon_decay

        self.is_training = False

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('MinimaxQ DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}
        self.V = {}
        self.pi = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Sub-case for CamRest
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = \
                    5 + \
                    len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

                self.NOtherActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])

            elif self.agent_role == 'user':
                self.NActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])
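
To make the action-space arithmetic above concrete, here is a small sketch using the 13-item DSTC2 action lists defined above and hypothetical slot counts (the real counts come from the domain's ontology file):

# Hypothetical slot counts -- the real numbers come from the ontology.
n_dstc2_sys = 13        # len(dstc2_acts_sys) above
n_dstc2_usr = 13        # len(dstc2_acts_usr) above
n_requestable = 8       # e.g. len(ontology['requestable'])
n_sys_requestable = 3   # e.g. len(ontology['system_requestable'])

# System agent: its own actions, and the user actions it has to model
NActions_system = n_dstc2_sys + n_requestable + n_sys_requestable   # 24
NOtherActions_system = n_dstc2_usr + 2 * n_requestable              # 29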
Example #3
class MinimaxQPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 alpha=0.25,
                 gamma=0.95,
                 epsilon=0.25,
                 alpha_decay=0.9995,
                 epsilon_decay=0.995):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay = alpha_decay
        self.epsilon_decay = epsilon_decay

        self.is_training = False

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('MinimaxQ DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}
        self.V = {}
        self.pi = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Sub-case for CamRest
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        # Does not include inform and request that are modelled together with
        # their arguments
        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = \
                    5 + \
                    len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

                self.NOtherActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])

            elif self.agent_role == 'user':
                self.NActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

    def initialize(self, **kwargs):
        """
        Initialize internal parameters

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    self.warmup_simulator.initialize(
                        {'goal': kwargs['goal']})
                else:
                    print('WARNING! No goal provided for MinimaxQ policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return:
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for MinimaxQ policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)

        if state_enc not in self.pi or \
                (self.is_training and random.random() < self.epsilon):
            if not self.is_training:
                if not self.pi:
                    print(f'\nWARNING! Minimax Q {self.agent_role} matrix is '
                          f'empty. Did you load the correct file?\n')
                else:
                    print(f'\nWARNING! Minimax {self.agent_role} state not '
                          f'found in Q matrix.\n')

            if random.random() < 0.5:
                print('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                print('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        # Return best action
        max_pi = max(self.pi[state_enc])
        maxima = [i for i, j in enumerate(self.pi[state_enc]) if j == max_pi]

        # Break ties randomly
        if maxima:
            sys_acts = \
                self.decode_action(
                    random.choice(maxima), self.agent_role == 'system')
        else:
            print('--- {0}: Warning! No maximum value identified for policy. '
                  'Selecting random action.'.format(self.agent_role))
            return self.decode_action(random.choice(range(0, self.NActions)),
                                      self.agent_role == 'system')

        return sys_acts

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state encoding
        """
        temp = [int(state.is_terminal_state), int(state.system_made_offer)]

        # If the agent plays the role of the user it needs access to its own
        # goal
        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests need
            # to be communicated and which of them actually have been.
            if state.user_goal:
                found_unanswered_constr = False
                found_unanswered_req = False

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints and \
                                c not in state.user_goal.actual_constraints:
                            found_unanswered_constr = True
                            break

                for r in self.requestable_slots:
                    if r in state.user_goal.requests and \
                            not state.user_goal.requests[r].value:
                        found_unanswered_req = True
                        break

                temp += \
                    [int(found_unanswered_constr), int(found_unanswered_req)]
            else:
                temp += [0, 0]

        if self.agent_role == 'system':
            temp.append(int(state.is_terminal()))
            temp.append(int(state.system_made_offer))

            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t

        return state_enc

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: MinimaxQ DialoguePolicy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(
                           action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        if (self.agent_role == 'system') == system:
            print('MinimaxQ ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode: {1})!'.format(
                      self.agent_role, action))
        else:
            print('MinimaxQ ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode other agent action: '
                  '{1})!'.format(self.agent_role, action))

        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        print('MinimaxQ DialoguePolicy ({0}) policy action decoder warning: '
              'Selecting repeat() action '
              '(index: {1})!'.format(self.agent_role, action_enc))
        return [DialogueAct('repeat', [])]

    def train(self, dialogues):
        """
        Train the model using MinimaxQ.

        :param dialogues: a list of dialogues, each being a list of dialogue
                          turns (state, action, reward triplets).
        :return:
        """

        if not self.is_training:
            return

        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])
                action_enc = \
                    self.encode_action(
                        turn['action'],
                        self.agent_role == 'system')
                other_action_enc = \
                    self.encode_action(
                        turn['state'].user_acts,
                        self.agent_role != 'system')

                if action_enc < 0 or other_action_enc < 0 or \
                        turn['action'][0].intent == 'bye':
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = []

                    for oa in range(self.NOtherActions):
                        self.Q[state_enc].append([])

                        for a in range(self.NActions):
                            self.Q[state_enc][oa].append(1)

                if state_enc not in self.pi:
                    # Start from a uniform distribution over actions
                    self.pi[state_enc] = \
                        [float(1 / self.NActions)] * self.NActions

                # Q[state_enc] rows are pre-allocated with NActions entries,
                # so action_enc is always a valid index here

                if new_state_enc not in self.V:
                    self.V[new_state_enc] = 0

                if new_state_enc not in self.pi:
                    self.pi[new_state_enc] = \
                        [float(1 / self.NActions)] * self.NActions

                delta = turn['reward'] + self.gamma * self.V[new_state_enc]

                # Only update Q values (actor) that lead to an increase in Q
                # if delta > self.Q[state_enc][other_action_enc][action_enc]:
                self.Q[state_enc][other_action_enc][action_enc] += \
                    self.alpha * delta

                # Update V (critic)
                self.V[state_enc] = self.maxmin(state_enc)

        # Decay learning rate after each episode
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate after each episode
        if self.epsilon > 0.25:
            self.epsilon *= self.epsilon_decay

        print('MiniMaxQ [alpha: {0}, epsilon: {1}]'.format(
            self.alpha, self.epsilon))

    def maxmin(self, state_enc, retry=False):
        """
        Solve the maxmin problem

        :param state_enc: the encoding of the state
        :param retry: whether this call is a retry after a failed optimisation
        :return: the value of the maxmin solution for this state
        """

        c = np.zeros(self.NActions + 1)
        c[0] = -1
        A_ub = np.ones((self.NOtherActions, self.NActions + 1))
        A_ub[:, 1:] = -np.asarray(self.Q[state_enc])
        b_ub = np.zeros(self.NOtherActions)
        A_eq = np.ones((1, self.NActions + 1))
        A_eq[0, 0] = 0
        b_eq = [1]
        bounds = ((None, None), ) + ((0, 1), ) * self.NActions

        res = linprog(c,
                      A_ub=A_ub,
                      b_ub=b_ub,
                      A_eq=A_eq,
                      b_eq=b_eq,
                      bounds=bounds)

        if res.success:
            self.pi[state_enc] = res.x[1:]
        elif not retry:
            return self.maxmin(state_enc, retry=True)
        else:
            print("Alert : %s" % res.message)
            if state_enc in self.V:
                return self.V[state_enc]
            else:
                print('Warning, state not in V, returning 0.')
                return 0

        return res.x[0]

    def save(self, path=None):
        """
        Save the model in the path provided

        :param path: path to save the model to
        :return: nothing
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/minimax_q_policy.pkl'
            print('No policy file name provided. Using default: {0}'.format(
                path))

        obj = {
            'Q': self.Q,
            'V': self.V,
            'pi': self.pi,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        """
        Load the model from the path provided

        :param path: path to load the model from
        :return: nothing
        """

        if not path:
            print('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'V' in obj:
                        self.V = obj['V']
                    if 'pi' in obj:
                        self.pi = obj['pi']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('Q DialoguePolicy loaded from {0}.'.format(path))

            else:
                print('Warning! Q DialoguePolicy file %s not found' % path)
        else:
            print('Warning! Unacceptable value for Q policy file name: %s ' %
                  path)
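
The maxmin method above solves, for each state, the linear program max over pi of (min over opponent actions o of sum_a pi(a) * Q[s][o][a]). The sketch below reproduces that exact linprog setup on a tiny hypothetical Q matrix (numpy and scipy assumed installed); for the 2x2 identity game it recovers the expected mixed strategy [0.5, 0.5] with value 0.5.

# Standalone sketch of the maxmin linear program used above.
# Variable x = [v, pi_1, ..., pi_N]: maximise v subject to
# v <= sum_a pi_a * Q[o][a] for every opponent action o, sum_a pi_a = 1.
import numpy as np
from scipy.optimize import linprog

Q = np.array([[1.0, 0.0],    # hypothetical Q[state][other_action][action]
              [0.0, 1.0]])
n_actions = Q.shape[1]

c = np.zeros(n_actions + 1)
c[0] = -1                                  # linprog minimises, so minimise -v
A_ub = np.ones((Q.shape[0], n_actions + 1))
A_ub[:, 1:] = -Q                           # v - sum_a pi_a * Q[o][a] <= 0
b_ub = np.zeros(Q.shape[0])
A_eq = np.ones((1, n_actions + 1))
A_eq[0, 0] = 0                             # probabilities sum to 1
b_eq = [1]
bounds = ((None, None),) + ((0, 1),) * n_actions

res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds)
print(res.x[1:], res.x[0])                 # pi ~ [0.5, 0.5], value ~ 0.5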
Example #4
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.2,
                 epsilon=0.95,
                 gamma=0.95,
                 alpha_decay=0.995,
                 epsilon_decay=0.9995,
                 epsilon_min=0.05):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        super(ReinforcePolicy, self).__init__()

        self.logger = logging.getLogger(__name__)

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.IS_GREEDY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.weights = None
        self.sess = None

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay_rate = alpha_decay
        self.exploration_decay_rate = epsilon_decay
        self.epsilon_min = epsilon_min

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        # Ensure the CamRest action lists exist even when the domain given
        # below does not define them
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                self.logger.warning(
                    'Warning! Domain has not been defined. Using '
                    'Slot-Filling Dialogue State')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            self.logger.info(
                'Reinforce DialoguePolicy {0} automatically determined '
                'number of state features: {1}'.format(self.agent_role,
                                                       self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        self.logger.info(
            'Reinforce {0} DialoguePolicy Number of Actions: {1}'.format(
                self.agent_role, self.NActions))
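
With the default hyper-parameters above (epsilon=0.95, epsilon_decay=0.9995, epsilon_min=0.05) and one multiplicative decay per call to train() (see decay_epsilon() in the next example), exploration keeps decaying for several thousand training calls. A quick back-of-the-envelope check:

# Rough check of the exploration schedule implied by the defaults above.
import math

epsilon, epsilon_decay, epsilon_min = 0.95, 0.9995, 0.05
steps_to_min = math.log(epsilon_min / epsilon) / math.log(epsilon_decay)
print(round(steps_to_min))   # roughly 5900 decay steps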
Example #5
class ReinforcePolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.2,
                 epsilon=0.95,
                 gamma=0.95,
                 alpha_decay=0.995,
                 epsilon_decay=0.9995,
                 epsilon_min=0.05):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        super(ReinforcePolicy, self).__init__()

        self.logger = logging.getLogger(__name__)

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.IS_GREEDY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.weights = None
        self.sess = None

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay_rate = alpha_decay
        self.exploration_decay_rate = epsilon_decay
        self.epsilon_min = epsilon_min

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        # Ensure the CamRest action lists exist even when the domain given
        # below does not define them (encode_action/decode_action read them)
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                self.logger.warning(
                    'Warning! Domain has not been defined. Using '
                    'Slot-Filling Dialogue State')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            self.logger.info(
                'Reinforce DialoguePolicy {0} automatically determined '
                'number of state features: {1}'.format(self.agent_role,
                                                       self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        self.logger.info(
            'Reinforce {0} DialoguePolicy Number of Actions: {1}'.format(
                self.agent_role, self.NActions))

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    self.warmup_simulator.initialize(
                        {'goal': kwargs['goal']})
                else:
                    self.logger.warning(
                        'WARNING! No goal provided for Reinforce policy '
                        'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.alpha = kwargs['learning_rate']

        if 'learning_decay_rate' in kwargs:
            self.alpha_decay_rate = kwargs['learning_decay_rate']

        if 'discount_factor' in kwargs:
            self.gamma = kwargs['discount_factor']

        if 'exploration_rate' in kwargs:
            self.epsilon = kwargs['exploration_rate']

        if 'exploration_decay_rate' in kwargs:
            self.exploration_decay_rate = kwargs['exploration_decay_rate']

        if self.weights is None:
            self.weights = np.random.rand(self.NStateFeatures, self.NActions)

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                self.logger.warning(
                    'WARNING! No goal provided for Reinforce policy user '
                    'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training and random.random() < self.epsilon:
            if random.random() < 0.75:
                self.logger.debug('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                self.logger.debug('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == "system")

        # Probabilistic policy: Sample from action wrt probabilities
        probs = self.calculate_policy(self.encode_state(state))

        if any(np.isnan(probs)):
            self.logger.warning(
                'WARNING! NAN detected in action probabilities! Selecting '
                'random action.')
            return self.decode_action(random.choice(range(0, self.NActions)),
                                      self.agent_role == "system")

        if self.IS_GREEDY:
            # Get greedy action
            max_pi = max(probs)
            maxima = [i for i, j in enumerate(probs) if j == max_pi]

            # Break ties randomly
            if maxima:
                sys_acts = \
                    self.decode_action(
                        random.choice(maxima), self.agent_role == 'system')
            else:
                self.logger.warning(
                    f'--- {self.agent_role}: Warning! No maximum value '
                    f'identified for policy. Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')
        else:
            # Pick from top 3 actions
            top_3 = np.argsort(-probs)[0:3]
            sys_acts = \
                self.decode_action(
                    random.choices(
                        top_3, probs[top_3])[0], self.agent_role == 'system')

        return sys_acts

    @staticmethod
    def softmax(x):
        """
        Calculates the softmax of x

        :param x: a vector of real numbers
        :return: the softmax of x
        """
        e_x = np.exp(x - np.max(x))
        out = e_x / e_x.sum()
        return out

    @staticmethod
    def softmax_gradient(x):
        """
        Calculates the gradient of the softmax

        :param x: a vector of probabilities (e.g. a softmax output)
        :return: the Jacobian of the softmax evaluated at x
        """
        x = np.asarray(x)
        x_reshaped = x.reshape(-1, 1)
        return np.diagflat(x_reshaped) - np.dot(x_reshaped, x_reshaped.T)

    def calculate_policy(self, state):
        """
        Calculates the probabilities for each action from the given state

        :param state: the encoded dialogue state (feature vector)
        :return: probabilities of actions
        """
        dot_prod = np.dot(state, self.weights)
        exp_dot_prod = np.exp(dot_prod)
        return exp_dot_prod / np.sum(exp_dot_prod)

    def train(self, dialogues):
        """
        Train the policy network

        :param dialogues: dialogue experience
        :return: nothing
        """
        # If called by accident
        if not self.is_training:
            return

        for dialogue in dialogues:
            discount = self.gamma

            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            rewards = [t['reward'] for t in dialogue]
            norm_rewards = \
                (rewards - np.mean(rewards)) / (np.std(rewards) + 0.000001)

            for (t, turn) in enumerate(dialogue):
                act_enc = self.encode_action(turn['action'],
                                             self.agent_role == 'system')
                if act_enc < 0:
                    continue

                state_enc = self.encode_state(turn['state'])

                if len(state_enc) != self.NStateFeatures:
                    raise ValueError(f'Reinforce DialoguePolicy '
                                     f'{self.agent_role} mismatch in state '
                                     f'dimensions: State Features: '
                                     f'{self.NStateFeatures} != State '
                                     f'Encoding Length: {len(state_enc)}')

                # Calculate the gradients

                # Call policy again to retrieve the probability of the
                # action taken
                probabilities = self.calculate_policy(state_enc)

                softmax_deriv = self.softmax_gradient(probabilities)[act_enc]
                log_policy_grad = softmax_deriv / probabilities[act_enc]
                gradient = \
                    np.asarray(
                        state_enc)[None, :].transpose().dot(
                        log_policy_grad[None, :])
                gradient = np.clip(gradient, -1.0, 1.0)

                # Train policy
                self.weights += \
                    self.alpha * gradient * norm_rewards[t] * discount
                self.weights = np.clip(self.weights, -1, 1)

                discount *= self.gamma

        if self.alpha > 0.01:
            self.alpha *= self.alpha_decay_rate

        self.decay_epsilon()

        self.logger.info(
            f'REINFORCE train, alpha: {self.alpha}, epsilon: {self.epsilon}')

    def decay_epsilon(self):
        """
        Decays epsilon (exploration rate) by epsilon decay.

         Decays epsilon (exploration rate) by epsilon decay.
         If epsilon is already less or equal compared to epsilon_min,
         the call of this method has no effect.

        :return:
        """
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.exploration_decay_rate

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: list - the encoded state vector
        """

        temp = [int(state.is_terminal_state), int(state.system_made_offer)]

        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them actually have been.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:

                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)

            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r in state.requested_slots else 0)

        return temp

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        # TODO: Action encoding in a principled way
        if not actions:
            self.logger.warning(
                'WARNING: Reinforce DialoguePolicy action encoding called '
                'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        # Default fall-back action
        self.logger.warning(
            'Reinforce ({0}) policy action encoder warning: Selecting '
            'default action (unable to encode: {1})!'.format(
                self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = action_enc - len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        self.logger.warning(
            'Reinforce DialoguePolicy ({0}) action decoder warning: '
            'Selecting default action (index: {1})!'.format(
                self.agent_role, action_enc))
        return [DialogueAct('bye', [])]
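
    # A round-trip sketch (illustrative only, not part of the original
    # class): encode_action and decode_action above are meant to be inverses
    # over a flat action index, with the DSTC2 intents first, followed by
    # request actions over one slot list and inform actions over the other
    # (the two slot lists are swapped between the 'system' and 'user' roles).
    # Assuming `p` is a policy instance for the 'system' role, with its slot
    # lists populated from the ontology, and the index falls in the DSTC2
    # intent range:
    #
    #     acts = p.decode_action(3, system=True)
    #     assert p.encode_action(acts, system=True) == 3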

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/reinforce.pkl'
            self.logger.warning(
                'No policy file name provided. Using default: {0}'.format(
                    path))

        obj = {
            'weights': self.weights,
            'alpha': self.alpha,
            'alpha_decay_rate': self.alpha_decay_rate,
            'epsilon': self.epsilon,
            'exploration_decay_rate': self.exploration_decay_rate,
            'epsilon_min': self.epsilon_min
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path=None):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return:
        """

        if not path:
            self.logger.warning('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'weights' in obj:
                        self.weights = obj['weights']

                    if 'alpha' in obj:
                        self.alpha = obj['alpha']

                    if 'alpha_decay_rate' in obj:
                        self.alpha_decay_rate = obj['alpha_decay_rate']

                    if 'epsilon' in obj:
                        self.epsilon = obj['epsilon']

                    if 'exploration_decay_rate' in obj:
                        self.exploration_decay_rate = \
                            obj['exploration_decay_rate']

                    if 'epsilon_min' in obj:
                        self.epsilon_min = obj['epsilon_min']

                    self.logger.info(
                        'Reinforce DialoguePolicy loaded from {0}.'.format(
                            path))

            else:
                self.logger.warning(
                    'Warning! Reinforce DialoguePolicy file %s not found' %
                    path)
        else:
            self.logger.warning(
                'Warning! Unacceptable value for Reinforce DialoguePolicy '
                'file name: %s ' % path)
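
A small usage sketch (not from the original source): save() pickles a plain
dictionary of the policy's weights and its learning/exploration schedules, and
load() restores the same fields, so a trained Reinforce policy can be persisted
between runs. The path below is a hypothetical example; save() falls back to
'Models/Policies/reinforce.pkl' when no path is given and is a no-op unless the
policy is in training mode.

    policy.save('Models/Policies/my_reinforce.pkl')

    # Later, in a fresh process, restore the weights and schedules.
    policy.load('Models/Policies/my_reinforce.pkl')
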
class ConversationalSingleAgent(ConversationalAgent):
    """
    Essentially the dialogue system. Will be able to interact with:

    - Simulated Users via:
        - Dialogue Acts
        - Text

    - Human Users via:
        - Text
        - Speech
        - Online crowd?

    - Data
    """
    def __init__(self, configuration):
        """
        Initialize the internal structures of this agent.

        :param configuration: a dictionary representing the configuration file
        """

        super(ConversationalSingleAgent, self).__init__()

        self.configuration = configuration

        # There is only one agent in this setting
        self.agent_id = 0

        # Dialogue statistics
        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0
        self.total_dialogue_turns = 0

        self.minibatch_length = 500
        self.train_interval = 50
        self.train_epochs = 10

        # True values here would imply some default modules
        self.USE_USR_SIMULATOR = False
        self.USER_SIMULATOR_NLU = False
        self.USER_SIMULATOR_NLG = False
        self.USE_NLG = False
        self.USE_SPEECH = False
        self.USER_HAS_INITIATIVE = True
        self.SAVE_LOG = True

        # The dialogue will terminate after MAX_TURNS (this agent will issue
        # a bye() dialogue act).
        self.MAX_TURNS = 15

        self.dialogue_turn = -1
        self.ontology = None
        self.database = None
        self.domain = None
        self.dialogue_manager = None
        self.user_model = None
        self.user_simulator = None
        self.user_simulator_args = {}
        self.nlu = None
        self.nlg = None

        self.agent_role = None
        self.agent_goal = None
        self.goal_generator = None

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.user_model = UserModel()

        self.recorder = DialogueEpisodeRecorder()

        # TODO: Handle this properly - get reward function type from config
        self.reward_func = SlotFillingReward()
        # self.reward_func = SlotFillingGoalAdvancementReward()

        if self.configuration:
            # Error checks for options the config must have
            if not self.configuration['GENERAL']:
                raise ValueError('Cannot run Plato without GENERAL settings!')

            elif not self.configuration['GENERAL']['interaction_mode']:
                raise ValueError('Cannot run Plato without an '
                                 'interaction mode!')

            elif not self.configuration['DIALOGUE']:
                raise ValueError('Cannot run Plato without DIALOGUE settings!')

            elif not self.configuration['AGENT_0']:
                raise ValueError('Cannot run Plato without at least '
                                 'one agent!')

            # Dialogue domain settings
            if 'DIALOGUE' in self.configuration and \
                    self.configuration['DIALOGUE']:
                if 'initiative' in self.configuration['DIALOGUE']:
                    self.USER_HAS_INITIATIVE = bool(
                        self.configuration['DIALOGUE']['initiative'] == 'user')
                    self.user_simulator_args['us_has_initiative'] = \
                        self.USER_HAS_INITIATIVE

                if self.configuration['DIALOGUE']['domain']:
                    self.domain = self.configuration['DIALOGUE']['domain']

                if self.configuration['DIALOGUE']['ontology_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['ontology_path']):
                        self.ontology = Ontology.Ontology(
                            self.configuration['DIALOGUE']['ontology_path'])
                    else:
                        raise FileNotFoundError(
                            'Domain file %s not found' %
                            self.configuration['DIALOGUE']['ontology_path'])

                if self.configuration['DIALOGUE']['db_path']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['db_path']):
                        if 'db_type' in self.configuration['DIALOGUE']:
                            if self.configuration['DIALOGUE']['db_type'] == \
                                    'sql':
                                self.database = DataBase.SQLDataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                            else:
                                self.database = DataBase.DataBase(
                                    self.configuration['DIALOGUE']['db_path'])
                        else:
                            # Default to SQL
                            self.database = DataBase.SQLDataBase(
                                self.configuration['DIALOGUE']['db_path'])
                    else:
                        raise FileNotFoundError(
                            'Database file %s not found' %
                            self.configuration['DIALOGUE']['db_path'])

                if 'goals_path' in self.configuration['DIALOGUE']:
                    if os.path.isfile(
                            self.configuration['DIALOGUE']['goals_path']):
                        self.goals_path = \
                            self.configuration['DIALOGUE']['goals_path']
                    else:
                        raise FileNotFoundError(
                            'Goals file %s not found' %
                            self.configuration['DIALOGUE']['goals_path'])

            # General settings
            if 'GENERAL' in self.configuration and \
                    self.configuration['GENERAL']:
                if 'experience_logs' in self.configuration['GENERAL']:
                    dialogues_path = None
                    if 'path' in \
                            self.configuration['GENERAL']['experience_logs']:
                        dialogues_path = \
                            self.configuration['GENERAL'][
                                'experience_logs']['path']

                    if 'load' in \
                            self.configuration['GENERAL']['experience_logs'] \
                        and bool(
                            self.configuration['GENERAL'][
                                'experience_logs']['load']
                    ):
                        if dialogues_path and os.path.isfile(dialogues_path):
                            self.recorder.load(dialogues_path)
                        else:
                            raise FileNotFoundError(
                                'Dialogue Log file %s not found (did you '
                                'provide one?)' % dialogues_path)

                    if 'save' in \
                            self.configuration['GENERAL']['experience_logs']:
                        self.recorder.set_path(dialogues_path)
                        self.SAVE_LOG = bool(self.configuration['GENERAL']
                                             ['experience_logs']['save'])

                if self.configuration['GENERAL']['interaction_mode'] == \
                        'simulation':
                    self.USE_USR_SIMULATOR = True

                elif self.configuration['GENERAL']['interaction_mode'] == \
                        'speech':
                    self.USE_SPEECH = True
                    self.asr = speech_rec.Recognizer()

            # Agent Settings

            # Usr Simulator
            # Check for specific simulator settings, otherwise
            # default to agenda
            if 'USER_SIMULATOR' in self.configuration['AGENT_0']:
                # Agent 0 simulator configuration
                a0_sim_config = self.configuration['AGENT_0']['USER_SIMULATOR']
                if a0_sim_config and a0_sim_config['simulator']:
                    # Default settings
                    self.user_simulator_args['ontology'] = self.ontology
                    self.user_simulator_args['database'] = self.database
                    self.user_simulator_args['um'] = self.user_model
                    self.user_simulator_args['patience'] = 5

                    if a0_sim_config['simulator'] == 'agenda':
                        if 'patience' in a0_sim_config:
                            self.user_simulator_args['patience'] = \
                                int(a0_sim_config['patience'])

                        if 'pop_distribution' in a0_sim_config:
                            if isinstance(a0_sim_config['pop_distribution'],
                                          list):
                                self.user_simulator_args['pop_distribution'] =\
                                    a0_sim_config['pop_distribution']
                            else:
                                self.user_simulator_args['pop_distribution'] =\
                                    eval(a0_sim_config['pop_distribution'])

                        if 'slot_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['slot_confuse_prob'] = \
                                float(a0_sim_config['slot_confuse_prob'])
                        if 'op_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['op_confuse_prob'] = \
                                float(a0_sim_config['op_confuse_prob'])
                        if 'value_confuse_prob' in a0_sim_config:
                            self.user_simulator_args['value_confuse_prob'] = \
                                float(a0_sim_config['value_confuse_prob'])

                        if 'goal_slot_selection_weights' in a0_sim_config:
                            self.user_simulator_args[
                                'goal_slot_selection_weights'] = a0_sim_config[
                                    'goal_slot_selection_weights']

                        if 'nlu' in a0_sim_config:
                            self.user_simulator_args['nlu'] = \
                                a0_sim_config['nlu']

                            if self.user_simulator_args['nlu'] == 'dummy':
                                self.user_simulator_args['database'] = \
                                    self.database

                            self.USER_SIMULATOR_NLU = True

                        if 'nlg' in a0_sim_config:
                            self.user_simulator_args['nlg'] = \
                                a0_sim_config['nlg']

                            if self.user_simulator_args['nlg'] == 'CamRest':
                                if a0_sim_config:
                                    self.user_simulator_args[
                                        'nlg_model_path'] = a0_sim_config[
                                            'nlg_model_path']

                                    self.USER_SIMULATOR_NLG = True

                                else:
                                    raise ValueError(
                                        'Usr Simulator NLG: Cannot find '
                                        'model_path in the config.')

                            elif self.user_simulator_args['nlg'] == 'dummy':
                                self.USER_SIMULATOR_NLG = True

                        if 'goals_file' in a0_sim_config:
                            self.user_simulator_args['goals_file'] = \
                                a0_sim_config['goals_file']

                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']

                        self.user_simulator = AgendaBasedUS(
                            self.user_simulator_args)

                    elif a0_sim_config['simulator'] == 'dtl':
                        if 'policy_file' in a0_sim_config:
                            self.user_simulator_args['policy_file'] = \
                                a0_sim_config['policy_file']
                            self.user_simulator = DTLUserSimulator(
                                self.user_simulator_args)
                        else:
                            raise ValueError(
                                'Error! Cannot start DAct-to-Language '
                                'simulator without a policy file!')

                else:
                    # Fallback to agenda based simulator with default settings
                    self.user_simulator = AgendaBasedUS(
                        self.user_simulator_args)

            # NLU Settings
            if 'NLU' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLU'] and \
                    self.configuration['AGENT_0']['NLU']['nlu']:
                nlu_args = dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))

                if self.configuration['AGENT_0']['NLU']['nlu'] == 'dummy':
                    self.nlu = DummyNLU(nlu_args)

                elif self.configuration['AGENT_0']['NLU']['nlu'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLU']['model_path']:
                        nlu_args['model_path'] = \
                            self.configuration['AGENT_0']['NLU']['model_path']
                        self.nlu = CamRestNLU(nlu_args)
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

            # NLG Settings
            if 'NLG' in self.configuration['AGENT_0'] and \
                    self.configuration['AGENT_0']['NLG'] and \
                    self.configuration['AGENT_0']['NLG']['nlg']:
                if self.configuration['AGENT_0']['NLG']['nlg'] == 'dummy':
                    self.nlg = DummyNLG()

                elif self.configuration['AGENT_0']['NLG']['nlg'] == 'CamRest':
                    if self.configuration['AGENT_0']['NLG']['model_path']:
                        self.nlg = CamRestNLG({
                            'model_path':
                            self.configuration['AGENT_0']['NLG']['model_path']
                        })
                    else:
                        raise ValueError(
                            'Cannot find model_path in the config.')

                if self.nlg:
                    self.USE_NLG = True

            # Retrieve agent role
            if 'role' in self.configuration['AGENT_0']:
                self.agent_role = self.configuration['AGENT_0']['role']
            else:
                raise ValueError(
                    'ConversationalAgent: No role assigned for agent {0} in '
                    'config!'.format(self.agent_id))

            if self.agent_role == 'user':
                if self.ontology and self.database:
                    self.goal_generator = GoalGenerator(ontology=self.ontology,
                                                        database=self.database)
                else:
                    raise ValueError(
                        'Conversational Multi Agent (user): Cannot generate '
                        'goal without ontology and database.')

        dm_args = dict(
            zip([
                'settings', 'ontology', 'database', 'domain', 'agent_id',
                'agent_role'
            ], [
                self.configuration, self.ontology, self.database, self.domain,
                self.agent_id, self.agent_role
            ]))
        dm_args.update(self.configuration['AGENT_0']['DM'])
        self.dialogue_manager = DialogueManager.DialogueManager(dm_args)
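
    # Configuration sketch (illustrative, not part of the original class):
    # the constructor above fails fast unless the configuration dictionary
    # carries a GENERAL section (with an interaction_mode), a DIALOGUE
    # section (whose 'domain', 'ontology_path' and 'db_path' keys are read
    # unconditionally), and an AGENT_0 section (with a 'role' and a 'DM'
    # sub-dictionary that is merged into the dialogue manager's arguments).
    # A minimal, hypothetical example (paths are placeholders; a
    # USER_SIMULATOR block would be added for an actual simulation run):
    #
    #     config = {
    #         'GENERAL': {'interaction_mode': 'simulation'},
    #         'DIALOGUE': {
    #             'domain': 'CamRest',
    #             'ontology_path': 'Domain/CamRestOntology.json',
    #             'db_path': 'Domain/CamRestDB.db',
    #             'db_type': 'sql',
    #         },
    #         'AGENT_0': {
    #             'role': 'system',
    #             'DM': {},
    #         },
    #     }
    #
    #     agent = ConversationalSingleAgent(config)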

    def __del__(self):
        """
        Do some house-keeping and save the models.

        :return: nothing
        """

        if self.recorder and self.SAVE_LOG:
            self.recorder.save()

        if self.dialogue_manager:
            self.dialogue_manager.save()

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def initialize(self):
        """
        Initializes the conversational agent based on settings in the
        configuration file.

        :return: Nothing
        """

        self.dialogue_episode = 0
        self.dialogue_turn = 0
        self.num_successful_dialogues = 0
        self.num_task_success = 0
        self.cumulative_rewards = 0

        if self.nlu:
            self.nlu.initialize({})

        if self.agent_role == 'user' and not self.agent_goal:
            self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.initialize({'goal': self.agent_goal})

        else:
            self.dialogue_manager.initialize({})

        if self.nlg:
            self.nlg.initialize({})

        self.curr_state = None
        self.prev_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

    def start_dialogue(self, args=None):
        """
        Perform initial dialogue turn.

        :param args: optional args
        :return:
        """

        self.dialogue_turn = 0
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            self.user_simulator.initialize(self.user_simulator_args)

            print('DEBUG > Usr goal:')
            print(self.user_simulator.goal)

        if self.agent_role == 'user':
            self.agent_goal = self.goal_generator.generate()
            self.dialogue_manager.restart({'goal': self.agent_goal})

        else:
            self.dialogue_manager.restart({})

        if not self.USER_HAS_INITIATIVE:
            # sys_response = self.dialogue_manager.respond()
            sys_response = [DialogueAct('welcomemsg', [])]

            if self.USE_NLG:
                sys_utterance = self.nlg.generate_output(
                    {'dacts': sys_response})
                print('SYSTEM > %s ' % sys_utterance)

                if self.USE_SPEECH:
                    try:
                        tts = gTTS(sys_utterance)
                        tts.save('sys_output.mp3')
                        os.system('afplay sys_output.mp3')

                    except Exception as e:
                        print('WARNING: gTTS encountered an error: {0}. '
                              'Falling back to Sys TTS.'.format(e))
                        os.system('say ' + sys_utterance)
            else:
                print('SYSTEM > %s ' %
                      '; '.join([str(sr) for sr in sys_response]))

            if self.USE_USR_SIMULATOR:
                usim_input = sys_response

                if self.USER_SIMULATOR_NLU and self.USE_NLG:
                    usim_input = self.user_simulator.nlu.process_input(
                        sys_utterance)

                self.user_simulator.receive_input(usim_input)
                rew, success, task_success = self.reward_func.calculate(
                    self.dialogue_manager.get_state(), sys_response,
                    self.user_simulator.goal)
            else:
                rew, success, task_success = 0, None, None

            self.recorder.record(deepcopy(self.dialogue_manager.get_state()),
                                 self.dialogue_manager.get_state(),
                                 sys_response,
                                 rew,
                                 success,
                                 task_success,
                                 output_utterance=sys_utterance)

            self.dialogue_turn += 1

        self.prev_state = None

        # Re-initialize these for good measure
        self.curr_state = None
        self.prev_usr_utterance = None
        self.prev_sys_utterance = None
        self.prev_action = None
        self.prev_reward = None
        self.prev_success = None
        self.prev_task_success = None

        self.continue_dialogue()

    def continue_dialogue(self):
        """
        Perform next dialogue turn.

        :return: nothing
        """

        usr_utterance = ''
        sys_utterance = ''

        if self.USE_USR_SIMULATOR:
            usr_input = self.user_simulator.respond()

            # TODO: THIS FIRST IF WILL BE HANDLED BY ConversationalAgentGeneric
            #  -- SHOULD NOT LIVE HERE
            if isinstance(self.user_simulator, DTLUserSimulator):
                print('USER (NLG) > %s \n' % usr_input)
                usr_input = self.nlu.process_input(
                    usr_input, self.dialogue_manager.get_state())

            elif self.USER_SIMULATOR_NLG:
                print('USER > %s \n' % usr_input)

                if self.nlu:
                    usr_input = self.nlu.process_input(usr_input)

                    # Otherwise it will just print the user's NLG but use the
                    # simulator's output DActs to proceed.

            else:
                print('USER (DACT) > %s \n' % usr_input[0])

        else:
            if self.USE_SPEECH:
                # Listen for input from the microphone
                with speech_rec.Microphone() as source:
                    print('(listening...)')
                    audio = self.asr.listen(source, phrase_time_limit=3)

                try:
                    # This uses the default key
                    usr_utterance = self.asr.recognize_google(audio)
                    print("Google ASR: " + usr_utterance)

                except speech_rec.UnknownValueError:
                    print("Google ASR did not understand you")

                except speech_rec.RequestError as e:
                    print("Google ASR request error: {0}".format(e))

            else:
                usr_utterance = input('USER > ')

            # Process the user's utterance
            if self.nlu:
                usr_input = self.nlu.process_input(
                    usr_utterance, self.dialogue_manager.get_state())
            else:
                raise EnvironmentError(
                    'ConversationalAgent: No NLU defined for '
                    'text-based interaction!')

        # DEBUG print
        # print(
        #     '\nSYSTEM NLU > %s ' % '; '.join([str(ui) for ui in usr_input])
        # )

        self.dialogue_manager.receive_input(usr_input)

        # Keep track of prev_state, for the DialogueEpisodeRecorder
        # Store here because this is the state that the dialogue manager
        # will use to make a decision.
        self.curr_state = deepcopy(self.dialogue_manager.get_state())

        # print('\nDEBUG> '+str(self.dialogue_manager.get_state()) + '\n')

        if self.dialogue_turn < self.MAX_TURNS:
            sys_response = self.dialogue_manager.generate_output()

        else:
            # Force dialogue stop
            # print(
            #     '{0}: terminating dialogue due to too '
            #     'many turns'.format(self.agent_role)
            # )
            sys_response = [DialogueAct('bye', [])]

        if self.USE_NLG:
            sys_utterance = self.nlg.generate_output({'dacts': sys_response})
            print('SYSTEM > %s ' % sys_utterance)

            if self.USE_SPEECH:
                try:
                    tts = gTTS(text=sys_utterance, lang='en')
                    tts.save('sys_output.mp3')
                    os.system('afplay sys_output.mp3')

                except Exception as e:
                    print('WARNING: gTTS encountered an error: {0}. '
                          'Falling back to Sys TTS.'.format(e))
                    os.system('say ' + sys_utterance)
        else:
            print('SYSTEM > %s ' % '; '.join([str(sr) for sr in sys_response]))

        if self.USE_USR_SIMULATOR:
            usim_input = sys_response

            if self.USER_SIMULATOR_NLU and self.USE_NLG:
                usim_input = \
                    self.user_simulator.nlu.process_input(sys_utterance)

                print('USER NLU '
                      '> %s ' % '; '.join([str(ui) for ui in usim_input]))

            self.user_simulator.receive_input(usim_input)
            rew, success, task_success = \
                self.reward_func.calculate(
                    self.dialogue_manager.get_state(),
                    sys_response,
                    self.user_simulator.goal
                )
        else:
            rew, success, task_success = 0, None, None

        if self.prev_state:
            self.recorder.record(self.prev_state,
                                 self.curr_state,
                                 self.prev_action,
                                 self.prev_reward,
                                 self.prev_success,
                                 input_utterance=usr_utterance,
                                 output_utterance=sys_utterance)

        self.dialogue_turn += 1

        self.prev_state = deepcopy(self.curr_state)
        self.prev_action = deepcopy(sys_response)
        self.prev_usr_utterance = deepcopy(usr_utterance)
        self.prev_sys_utterance = deepcopy(sys_utterance)
        self.prev_reward = rew
        self.prev_success = success
        self.prev_task_success = task_success

    def end_dialogue(self):
        """
        Perform final dialogue turn. Train and save models if applicable.

        :return: nothing
        """

        # Record final state
        self.recorder.record(self.curr_state,
                             self.curr_state,
                             self.prev_action,
                             self.prev_reward,
                             self.prev_success,
                             input_utterance=self.prev_usr_utterance,
                             output_utterance=self.prev_sys_utterance,
                             task_success=self.prev_task_success)

        if self.dialogue_manager.is_training():
            if self.dialogue_episode % self.train_interval == 0 and \
                    len(self.recorder.dialogues) >= self.minibatch_length:
                for epoch in range(self.train_epochs):
                    print('Training epoch {0} of {1}'.format(
                        epoch, self.train_epochs))

                    # Sample minibatch
                    minibatch = random.sample(self.recorder.dialogues,
                                              self.minibatch_length)

                    if self.nlu:
                        self.nlu.train(minibatch)

                    self.dialogue_manager.train(minibatch)

                    if self.nlg:
                        self.nlg.train(minibatch)

        self.dialogue_episode += 1
        self.cumulative_rewards += \
            self.recorder.dialogues[-1][-1]['cumulative_reward']
        print('CUMULATIVE REWARD: {0}'.format(
            self.recorder.dialogues[-1][-1]['cumulative_reward']))

        if self.dialogue_turn > 0:
            self.total_dialogue_turns += self.dialogue_turn

        if self.dialogue_episode % 10000 == 0:
            self.dialogue_manager.save()

        # Count successful dialogues
        if self.recorder.dialogues[-1][-1]['success']:
            print('SUCCESS (Subjective)!')
            self.num_successful_dialogues += \
                int(self.recorder.dialogues[-1][-1]['success'])

        else:
            print('FAILURE (Subjective).')

        if self.recorder.dialogues[-1][-1]['task_success']:
            self.num_task_success += \
                int(self.recorder.dialogues[-1][-1]['task_success'])

        print('OBJECTIVE TASK SUCCESS: {0}'.format(
            self.recorder.dialogues[-1][-1]['task_success']))

    def terminated(self):
        """
        Check if this agent is at a terminal state.

        :return: True or False
        """

        return self.dialogue_manager.at_terminal_state()
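
Taken together, the methods above define this agent's run loop: initialize()
once, then for every episode call start_dialogue(), keep calling
continue_dialogue() until terminated() reports a terminal state, and close with
end_dialogue() so that rewards are recorded and models are trained and saved.
A minimal driver sketch (illustrative; `agent` and `num_episodes` are assumed
to be provided by the caller):

    agent.initialize()

    for episode in range(num_episodes):
        agent.start_dialogue()

        while not agent.terminated():
            agent.continue_dialogue()

        agent.end_dialogue()
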
Exemplo n.º 8
0
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 alpha=0.25,
                 gamma=0.95,
                 epsilon=0.25,
                 alpha_decay=0.9995,
                 epsilon_decay=0.995,
                 epsilon_min=0.05,
                 warm_up_mode=False,
                 **kwargs):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        :param epsilon_min: the minimum exploration rate
        :param warm_up_mode: if True, prefer the warm-up (expert) policy when
                             exploring during training
        """

        self.logger = logging.getLogger(__name__)
        self.warm_up_mode = warm_up_mode
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay = alpha_decay
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.IS_GREEDY_POLICY = False

        # TODO: Put these as arguments in the config
        self.d_win = 0.0025
        self.d_lose = 0.01

        self.is_training = False

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('WoLF PHC DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}
        self.pi = {}
        self.mean_pi = {}
        self.state_counter = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = dict(
                zip(['ontology', 'database'], [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Extract lists of slots that are frequently used
        self.informable_slots = deepcopy(
            list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = deepcopy(
            self.ontology.ontology['requestable'])
        self.system_requestable_slots = deepcopy(
            self.ontology.ontology['system_requestable'])

        self.statistics = {'supervised_turns': 0, 'total_turns': 0}

        self.hash2actions = {}

        self.domain = setup_domain(self.ontology)
        self.NActions = self.domain.NActions
Exemplo n.º 9
0
class WoLFPHCPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 alpha=0.25,
                 gamma=0.95,
                 epsilon=0.25,
                 alpha_decay=0.9995,
                 epsilon_decay=0.995,
                 epsilon_min=0.05,
                 warm_up_mode=False,
                 **kwargs):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        :param epsilon_min: the minimum exploration rate
        :param warm_up_mode: if True, prefer the warm-up (expert) policy when
                             exploring during training
        """

        self.logger = logging.getLogger(__name__)
        self.warm_up_mode = warm_up_mode
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay = alpha_decay
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.IS_GREEDY_POLICY = False

        # TODO: Put these as arguments in the config
        self.d_win = 0.0025
        self.d_lose = 0.01

        self.is_training = False

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('WoLF PHC DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}
        self.pi = {}
        self.mean_pi = {}
        self.state_counter = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = dict(
                zip(['ontology', 'database'], [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Extract lists of slots that are frequently used
        self.informable_slots = deepcopy(
            list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = deepcopy(
            self.ontology.ontology['requestable'])
        self.system_requestable_slots = deepcopy(
            self.ontology.ontology['system_requestable'])

        self.statistics = {'supervised_turns': 0, 'total_turns': 0}

        self.hash2actions = {}

        self.domain = setup_domain(self.ontology)
        self.NActions = self.domain.NActions

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if 'learning_rate' in kwargs:
                self.alpha = float(kwargs['learning_rate'])

            if 'learning_decay_rate' in kwargs:
                self.alpha_decay = float(kwargs['learning_decay_rate'])

            if 'exploration_rate' in kwargs:
                self.epsilon = float(kwargs['exploration_rate'])

            if 'exploration_decay_rate' in kwargs:
                self.epsilon_decay = float(kwargs['exploration_decay_rate'])

            if 'gamma' in kwargs:
                self.gamma = float(kwargs['gamma'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    self.logger.warning(
                        'WARNING ! No goal provided for Supervised policy '
                        'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        self.counter = {'warmup': 0, 'learned': 0, 'random': 0}
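
        # Usage sketch (illustrative, not part of the original method): the
        # training hyperparameters above are only read when 'is_training' is
        # passed, so a training driver would typically hand them over in one
        # call, e.g. (hypothetical values):
        #
        #     policy.initialize(is_training=True,
        #                       learning_rate=0.2,
        #                       learning_decay_rate=0.995,
        #                       exploration_rate=0.3,
        #                       exploration_decay_rate=0.995,
        #                       gamma=0.9)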

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                self.logger.warning(
                    'WARNING! No goal provided for Supervised policy user '
                    'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        # state_enc = self.encode_state(state)
        state_enc = state_to_json(state)
        self.statistics['total_turns'] += 1

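        # Exploration / fallback: if this state has never been seen, or we are
        # training and an epsilon draw fires, delegate to the warm-up (expert)
        # policy or pick a random action, depending on the threshold below.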
        if state_enc not in self.pi or \
                (self.is_training and random.random() < self.epsilon):
            # if not self.is_training:
            #     if not self.pi:
            #         self.logger.warning(f'\nWARNING! WoLF-PHC pi is empty '
            #                             f'({self.agent_role}). Did you load the correct '
            #                             f'file?\n')
            #     else:
            #         self.logger.warning(f'\nWARNING! WoLF-PHC state not found in policy '
            #                             f'pi ({self.agent_role}).\n')
            threshold = 1.0 if self.warm_up_mode else 0.5
            if self.is_training and random.random() < threshold:
                # Use the warm-up / hand-crafted policy only during training
                self.logger.debug('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))
                self.statistics['supervised_turns'] += 1

                if self.agent_role == 'system':
                    self.counter['warmup'] += 1
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()
            else:
                self.logger.debug('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                sys_acts = create_random_dialog_act(self.domain,
                                                    is_system=True)
                self.counter['random'] += 1
                return sys_acts

        if self.IS_GREEDY_POLICY:
            # Get greedy action
            # Do not consider 'UNK' or an empty action
            state_actions = {}
            for k, v in self.pi[state_enc].items():
                if k and len(k) > 0:
                    state_actions[k] = v

            if len(state_actions) < 1:
                self.logger.warning(
                    '--- {0}: Warning! No valid actions found in the policy '
                    'for this state. Selecting random action.'.format(
                        self.agent_role))

                sys_acts = create_random_dialog_act(self.domain,
                                                    is_system=True)
                self.counter['random'] += 1
            else:

                # find all actions with same max_value
                max_value = max(state_actions.values())
                max_actions = [
                    k for k, v in state_actions.items() if v == max_value
                ]

                # break ties randomly
                action = random.choice(max_actions)
                sys_acts = self.decode_action(action, system=True)
                self.counter['learned'] += 1

        else:
            # Sample next action
            action_from_pi = random.choices(
                list(self.pi[state_enc].keys()),
                list(self.pi[state_enc].values()))[0]
            sys_acts = self.decode_action(action_from_pi,
                                          self.agent_role == 'system')

        assert sys_acts is not None
        return sys_acts

    def encode_action(self, acts: List[DialogueAct], system=True) -> str:
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param acts: the dialogue acts to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

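        # Strip slot values so that acts differing only in their values share
        # one encoding, serialise the result to a string key, and remember the
        # mapping so that decode_action can recover the (value-stripped) acts.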
        acts_copy = [deepcopy(x) for x in acts]
        for act in acts_copy:
            if act.params:
                for item in act.params:
                    if item.slot and item.value:
                        item.value = None

        s = action_to_string(acts_copy, system)
        # enc = int(hashlib.sha1(s.encode('utf-8')).hexdigest(), 32)
        self.hash2actions[s] = acts_copy
        return s

    def decode_action(self, action_enc: str, system=True) -> List[DialogueAct]:
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        return self.hash2actions[action_enc]

    def train(self, dialogues):
        """
        Train the model using WoLF-PHC.

        :param dialogues: a list of dialogues; each dialogue is a list of
                          dialogue turns (state, action, reward triplets).
        :return:
        """

        if not self.is_training:
            return

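        # The dialogue-level reward arrives with the terminal turn; copy it to
        # the preceding turn so that the updates below can make use of it.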
        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = state_to_json(turn['state'])
                new_state_enc = state_to_json(turn['new_state'])
                action_enc = \
                    self.encode_action(
                        turn['action'],
                        self.agent_role == 'system')

                # Skip unrecognised actions
                if not action_enc or turn['action'][0].intent == 'bye':
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = {}

                if action_enc not in self.Q[state_enc]:
                    self.Q[state_enc][action_enc] = 0

                if new_state_enc not in self.Q:
                    self.Q[new_state_enc] = {}

                # Add the current action to new_state so that it has at least
                # one value for the new state when updating Q
                # if action_enc not in self.Q[new_state_enc]:
                #    self.Q[new_state_enc][action_enc] = 0

                if state_enc not in self.pi:
                    # self.pi[state_enc] = \
                    #    [float(1/self.NActions)] * self.NActions
                    self.pi[state_enc] = {}

                if action_enc not in self.pi[state_enc]:
                    self.pi[state_enc][action_enc] = float(1 / self.NActions)

                if state_enc not in self.mean_pi:
                    # self.mean_pi[state_enc] = \
                    #     [float(1/self.NActions)] * self.NActions
                    self.mean_pi[state_enc] = {}

                if action_enc not in self.mean_pi[state_enc]:
                    self.mean_pi[state_enc][action_enc] = float(1 /
                                                                self.NActions)

                if state_enc not in self.state_counter:
                    self.state_counter[state_enc] = 1
                else:
                    self.state_counter[state_enc] += 1

                # Update Q
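                # Standard Q-learning target:
                # Q(s, a) <- (1 - alpha) * Q(s, a)
                #            + alpha * (r + gamma * max_a' Q(s', a'))
                # A successor state with no recorded actions contributes 0.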
                max_new_state = max(self.Q[new_state_enc].values()) if len(
                    self.Q[new_state_enc]) > 0 else 0
                self.Q[state_enc][action_enc] = \
                    ((1 - self.alpha) * self.Q[state_enc][action_enc]) + \
                    self.alpha * (
                            turn['reward'] +
                            (self.gamma * max_new_state))

                # Update mean policy estimate
                # for a in range(self.NActions):
                for a in self.mean_pi[state_enc].keys():
                    self.mean_pi[state_enc][a] = \
                        self.mean_pi[state_enc][a] + \
                        ((1.0 / self.state_counter[state_enc]) *
                         (self.pi[state_enc][a] - self.mean_pi[state_enc][a]))

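                # WoLF ("Win or Learn Fast") criterion: if the current policy
                # pi has a higher expected value under Q than the average
                # policy mean_pi, the agent is winning and should adapt
                # cautiously (d_win); otherwise it should learn fast (d_lose).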
                # Determine delta
                sum_policy = 0.0
                sum_mean_policy = 0.0

                # for a in range(self.NActions):
                for a in self.Q[state_enc].keys():
                    sum_policy = sum_policy + (self.pi[state_enc][a] *
                                               self.Q[state_enc][a])
                    sum_mean_policy = \
                        sum_mean_policy + \
                        (self.mean_pi[state_enc][a] * self.Q[state_enc][a])

                if sum_policy > sum_mean_policy:
                    delta = self.d_win
                else:
                    delta = self.d_lose

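                # Policy hill-climbing: add d_plus to the greedy (argmax-Q)
                # action and subtract d_plus / (NActions - 1) from every other
                # action, clipping each probability to [0, 1]. The distribution
                # is renormalised below.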
                # Update policy estimate
                max_action_Q = max(self.Q[state_enc],
                                   key=self.Q[state_enc].get)

                d_plus = delta
                d_minus = ((-1.0) * d_plus) / (self.NActions - 1.0)

                # for a in range(self.NActions):
                for a in self.Q[state_enc].keys():
                    if a == max_action_Q:
                        self.pi[state_enc][a] = \
                            min(1.0, self.pi[state_enc][a] + d_plus)
                    else:
                        self.pi[state_enc][a] = \
                            max(0.0, self.pi[state_enc][a] + d_minus)

                # Constrain pi to a legal probability distribution.
                # Use max(), since NActions is only an estimate of the true
                # action-space size.
                num_unseen_actions = max(
                    self.NActions - len(self.pi[state_enc]), 0)
                sum_unseen_actions = num_unseen_actions * float(
                    1 / self.NActions)
                sum_pi = sum(self.pi[state_enc].values()) + sum_unseen_actions
                # for a in range(self.NActions):
                for a in self.pi[state_enc].keys():
                    self.pi[state_enc][a] /= sum_pi

        # Decay learning rate after each episode
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate after each episode
        self.decay_epsilon()

        self.logger.info('[alpha: {0}, epsilon: {1}]'.format(
            self.alpha, self.epsilon))

    def decay_epsilon(self):
        """
        Decays epsilon (exploration rate) by epsilon_decay.

        If epsilon is already less than or equal to epsilon_min, or the policy
        is in warm-up mode, calling this method has no effect.

        :return:
        """
        if self.epsilon > self.epsilon_min and not self.warm_up_mode:
            self.epsilon *= self.epsilon_decay

    def save(self, path=None):
        """
        Saves the policy model to the path provided

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/wolf_phc_policy.pkl'
            self.logger.warning(
                'No policy file name provided. Using default: {0}'.format(
                    path))

        obj = {
            'Q': self.Q,
            'pi': self.pi,
            'mean_pi': self.mean_pi,
            'state_counter': self.state_counter,
            'a': self.alpha,
            'e': self.epsilon,
            'e_decay': self.epsilon_decay,
            'e_min': self.epsilon_min,
            'g': self.gamma,
            'hash2actions': self.hash2actions
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

        if self.statistics['total_turns'] > 0:
            self.logger.debug(
                '{0} WoLF PHC DialoguePolicy supervision ratio: {1}'.format(
                    self.agent_role,
                    float(self.statistics['supervised_turns'] /
                          self.statistics['total_turns'])))

        self.logger.debug(
            f'{self.agent_role} WoLF PHC DialoguePolicy state space '
            f'size: {len(self.pi)}')

    def load(self, path=None):
        """
        Load the policy model from the path provided

        :param path: path to load the model from
        :return:
        """

        if not path:
            self.logger.info('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'pi' in obj:
                        self.pi = obj['pi']
                    if 'mean_pi' in obj:
                        self.mean_pi = obj['mean_pi']
                    if 'state_counter' in obj:
                        self.state_counter = obj['state_counter']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'e_decay' in obj:
                        self.epsilon_decay = obj['e_decay']
                    if 'e_min' in obj:
                        self.epsilon_min = obj['e_min']
                    if 'g' in obj:
                        self.gamma = obj['g']
                    if 'hash2actions' in obj:
                        self.hash2actions = obj['hash2actions']

                    self.logger.info(
                        'WoLF-PHC DialoguePolicy loaded from {0}.'.format(
                            path))

            else:
                self.logger.warning(
                    'Warning! WoLF-PHC DialoguePolicy file %s not found' %
                    path)
        else:
            self.logger.warning(
                'Warning! Unacceptable value for WoLF-PHC policy file name: %s '
                % path)
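
# --- Illustrative sketch (not part of Plato) --------------------------------
# A toy, self-contained rehearsal of the WoLF-PHC update used in train()
# above, on a single-state, two-action problem with plain dicts. All names
# and numbers below are made up for illustration only.
if __name__ == '__main__':
    import random

    n_actions = 2
    alpha, gamma = 0.25, 0.95
    d_win, d_lose = 0.0025, 0.01

    Q = {'s': {'a0': 0.0, 'a1': 0.0}}
    pi = {'s': {'a0': 0.5, 'a1': 0.5}}
    mean_pi = {'s': {'a0': 0.5, 'a1': 0.5}}
    visits = 0

    for _ in range(200):
        a = random.choices(list(pi['s']), weights=list(pi['s'].values()))[0]
        reward = 1.0 if a == 'a0' else 0.0  # 'a0' is the better action
        visits += 1

        # Q-learning step (single state, so the successor is 's' itself)
        Q['s'][a] = (1 - alpha) * Q['s'][a] + \
            alpha * (reward + gamma * max(Q['s'].values()))

        # Average policy update
        for act in mean_pi['s']:
            mean_pi['s'][act] += (pi['s'][act] - mean_pi['s'][act]) / visits

        # WoLF criterion: small step when winning, large step when losing
        winning = sum(pi['s'][k] * Q['s'][k] for k in Q['s']) > \
            sum(mean_pi['s'][k] * Q['s'][k] for k in Q['s'])
        delta = d_win if winning else d_lose

        # Hill-climb towards the greedy action, then renormalise
        best = max(Q['s'], key=Q['s'].get)
        for act in pi['s']:
            step = delta if act == best else -delta / (n_actions - 1)
            pi['s'][act] = min(1.0, max(0.0, pi['s'][act] + step))
        total = sum(pi['s'].values())
        for act in pi['s']:
            pi['s'][act] /= total

    print(pi['s'])  # most of the probability mass should now be on 'a0'
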
Exemplo n.º 10
0
class QPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.95,
                 epsilon=0.95,
                 gamma=0.15,
                 alpha_decay=0.995,
                 epsilon_decay=0.995):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param domain: the dialogue's domain
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        self.is_training = False
        self.IS_GREEDY_POLICY = True

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('Q-Learning DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]

        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

                    self.NActions = \
                        len(self.dstc2_acts) + len(self.requestable_slots)

                    if self.agent_role == 'system':
                        self.NActions += len(self.system_requestable_slots)
                    else:
                        self.NActions += len(self.requestable_slots)

    def initialize(self, **kwargs):
        """
        Initialize internal parameters

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

        if 'agent_role' in kwargs:
            self.agent_role = kwargs['agent_role']

    def restart(self, args):
        """
        Nothing to do here.

        :return:
        """

        pass

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)

        if state_enc not in self.Q or (self.is_training
                                       and random.random() < self.epsilon):

            if random.random() < 0.5:
                # During exploration we may want to follow another policy,
                # e.g. an expert policy.

                print('---: Selecting warmup action.')

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)
                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                # Return a random action
                print('---: Selecting random action')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        if self.IS_GREEDY_POLICY:
            # Return action with maximum Q value from the given state
            sys_acts = self.decode_action(
                max(self.Q[state_enc], key=self.Q[state_enc].get),
                self.agent_role == 'system')
        else:
            sys_acts = self.decode_action(
                random.choices(range(0, self.NActions), self.Q[state_enc])[0],
                self.agent_role == 'system')

        return sys_acts

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state ID
        """

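        # Build a binary feature vector (turn counter, filled slots, requested
        # slot, DB match ratio, last acts, goal) and pack it into a single
        # integer at the end, so it can be used as a key into the Q table.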
        temp = []

        temp += [int(b) for b in format(state.turn, '06b')]

        for value in state.slots_filled.values():
            # This contains the requested slot
            temp.append(1 if value else 0)

        for slot in self.ontology.ontology['requestable']:
            temp.append(1 if slot == state.requested_slot else 0)

        temp.append(int(state.is_terminal_state))

        # If the agent is a system, then this shows what the top db result is.
        # If the agent is a user, then this shows what information the
        # system has provided
        if state.item_in_focus:
            for slot in self.ontology.ontology['requestable']:
                if slot in state.item_in_focus and state.item_in_focus[slot]:
                    temp.append(1)
                else:
                    temp.append(0)
        else:
            temp += [0] * len(self.ontology.ontology['requestable'])

        if state.db_matches_ratio >= 0:
            temp += \
                [int(b) for b in
                 format(int(round(state.db_matches_ratio, 2) * 100), '07b')]
        else:
            # If the number is negative (should not happen in general) there
            # will be a minus sign
            temp += \
                [int(b) for b in
                 format(int(round(state.db_matches_ratio, 2) * 100),
                        '07b')[1:]]

        temp.append(1 if state.system_made_offer else 0)

        if state.user_acts:
            temp += \
                [int(b) for b in
                 format(self.encode_action(state.user_acts, False), '05b')]
        else:
            temp += [0, 0, 0, 0, 0]

        if state.last_sys_acts:
            temp += \
                [int(b) for b in
                 format(self.encode_action([state.last_sys_acts[0]]), '04b')]
        else:
            temp += [0, 0, 0, 0]

        # If the agent plays the role of the user it needs access to its own
        # goal
        if state.user_goal:
            for c in self.ontology.ontology['informable']:
                if c in state.user_goal.constraints and \
                        state.user_goal.constraints[c].value:
                    temp.append(1)
                else:
                    temp.append(0)

            for r in self.ontology.ontology['requestable']:
                if r in state.user_goal.requests and \
                        state.user_goal.requests[r].value:
                    temp.append(1)
                else:
                    temp.append(0)
        else:
            # Just for symmetry, for all other roles append zeros
            temp += [0] * (len(self.ontology.ontology['informable']) +
                           len(self.ontology.ontology['requestable']))

        # Encode state
        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t

        return state_enc

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        # TODO: Action encoding in a principled way
        if not actions:
            print('WARNING: Q-Learning DialoguePolicy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        slot = None
        if action.params and action.params[0].slot:
            slot = action.params[0].slot

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if slot:
                if action.intent == 'request' and slot in \
                        self.system_requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           self.system_requestable_slots.index(slot)

                if action.intent == 'inform' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           len(self.system_requestable_slots) + \
                           self.requestable_slots.index(slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if slot:
                if action.intent == 'request' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           self.requestable_slots.index(slot)

                if action.intent == 'inform' and slot in \
                        self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           len(self.requestable_slots) + \
                           self.requestable_slots.index(slot)

        # Unable to encode action
        print('Q-Learning ({0}) policy action encoder warning: Selecting '
              'default action (unable to encode: {1})!'.format(
                  self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

    def train(self, dialogues):
        """
        Train the model using Q-learning.

        :param dialogues: a list of dialogues; each dialogue is a list of
                          dialogue turns (state, action, reward triplets).
        :return:
        """

        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])
                action_enc = self.encode_action(turn['action'])

                if action_enc < 0:
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = {}

                if action_enc not in self.Q[state_enc]:
                    self.Q[state_enc][action_enc] = 0

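                # Tabular Q-learning update:
                # Q(s, a) += alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)),
                # with unseen successor states contributing 0.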
                max_q = 0
                if new_state_enc in self.Q:
                    max_q = max(self.Q[new_state_enc].values())

                self.Q[state_enc][action_enc] += \
                    self.alpha * (turn['reward'] +
                                  self.gamma * max_q -
                                  self.Q[state_enc][action_enc])

        # Decay learning rate
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate
        if self.epsilon > 0.05:
            self.epsilon *= self.epsilon_decay

        print('Q-Learning: [alpha: {0}, epsilon: {1}]'.format(
            self.alpha, self.epsilon))

    def save(self, path=None):
        """
        Save the Q learning policy model

        :param path: the path to save the model to
        :return: nothing
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/q_policy.pkl'
            print('No policy file name provided. Using default: {0}'.format(
                path))

        obj = {
            'Q': self.Q,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path=None):
        """
        Loads the Q learning policy model

        :param path: the path to load the model from
        :return: nothing
        """

        if not path:
            print('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('Q DialoguePolicy loaded from {0}.'.format(path))

            else:
                print('Warning! Q DialoguePolicy file %s not found' % path)
        else:
            print('Warning! Unacceptable value for Q policy file name: %s ' %
                  path)
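
# --- Illustrative sketch (not part of Plato) --------------------------------
# How encode_state() above packs its list of binary features into a single
# integer that can serve as a Q-table key. The bits below are made up.
if __name__ == '__main__':
    bits = [1, 0, 1, 1, 0]      # e.g. turn bits, slot flags, offer flag, ...
    state_enc = 0
    for b in bits:
        state_enc = (state_enc << 1) | b
    print(state_enc)            # 0b10110 == 22
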
Exemplo n.º 11
0
class SupervisedPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param domain: the dialogue's domain
        """
        super(SupervisedPolicy, self).__init__()

        self.agent_id = agent_id
        self.agent_role = agent_role

        # True for greedy, False for stochastic
        self.IS_GREEDY_POLICY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.policy_net = None
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)
        self.sess = None

        # The system and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'] +
                     ['this', 'signature'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:
                DState = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

            else:
                print('Warning! Domain has not been defined. Using '
                      'Slot-Filling Dialogue State')
                DState = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            DState.initialize()
            self.NStateFeatures = len(self.encode_state(DState))
            print('Supervised DialoguePolicy automatically determined number '
                  'of state features: {0}'.format(self.NStateFeatures))

        if domain == 'CamRest':
            self.NActions = len(self.dstc2_acts) + len(self.requestable_slots)

            if self.agent_role == 'system':
                self.NActions += len(self.system_requestable_slots)
            else:
                self.NActions += len(self.requestable_slots)
        else:
            self.NActions = 5

        self.policy_alpha = 0.05

        self.tf_saver = None

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    print('WARNING! No goal provided for Supervised policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.policy_alpha = kwargs['learning_rate']

        if self.sess is None:
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()
            self.sess.run(tf.global_variables_initializer())

            self.tf_saver = \
                tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.tf_scope))

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return:
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for Supervised policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training:
            # This is a Supervised DialoguePolicy, so no exploration here.

            if self.agent_role == 'system':
                return self.warmup_policy.next_action(state)
            else:
                self.warmup_simulator.receive_input(state.user_acts,
                                                    state.user_goal)
                return self.warmup_simulator.generate_output()

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss = \
            self.policy_net

        obs_vector = np.expand_dims(self.encode_state(state), axis=0)

        probs = self.sess.run(pl_calculated, feed_dict={pl_state: obs_vector})

        if self.IS_GREEDY_POLICY:
            # Greedy policy: Return action with maximum value from the given
            # state
            sys_acts = \
                self.decode_action(
                    np.argmax(probs), self.agent_role == 'system')

        else:
            # Stochastic policy: Sample action wrt Q values
            if any(np.isnan(probs[0])):
                print('WARNING! Supervised DialoguePolicy: NAN detected in '
                      'action probabilities! Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

            # Make sure weights are positive
            min_p = min(probs[0])

            if min_p < 0:
                positive_weights = [p + abs(min_p) for p in probs[0]]
            else:
                positive_weights = probs[0]

            # Normalize weights
            positive_weights /= sum(positive_weights)

            sys_acts = \
                self.decode_action(
                    random.choices(
                        [a for a in range(self.NActions)],
                        weights=positive_weights)[0],
                    self.agent_role == 'system')

        return sys_acts

    def feed_forward_net_init(self):
        """
        Initialize the feed forward network.

        :return: some useful variables
        """
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

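        # TF1-style graph: two sigmoid hidden layers of NStateFeatures units
        # each, a softmax output over NActions, and an L2 loss between the
        # output and the supplied targets (one-hot expert actions during
        # training), minimised with Adam.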
        with tf.variable_scope(self.tf_scope):
            state = tf.placeholder("float", [None, self.NStateFeatures])
            newvals = tf.placeholder("float", [None, self.NActions])

            w1 = \
                tf.get_variable("w1",
                                [self.NStateFeatures, self.NStateFeatures])
            b1 = tf.get_variable("b1", [self.NStateFeatures])
            h1 = tf.nn.sigmoid(tf.matmul(state, w1) + b1)

            w2 = \
                tf.get_variable("w2",
                                [self.NStateFeatures, self.NStateFeatures])
            b2 = tf.get_variable("b2", [self.NStateFeatures])
            h2 = tf.nn.sigmoid(tf.matmul(h1, w2) + b2)

            w3 = tf.get_variable("w3", [self.NStateFeatures, self.NActions])
            b3 = tf.get_variable("b3", [self.NActions])

            calculated = tf.nn.softmax(tf.matmul(h2, w3) + b3)

            diffs = calculated - newvals
            loss = tf.nn.l2_loss(diffs)
            optimizer = \
                tf.train.AdamOptimizer(self.policy_alpha).minimize(loss)

            return calculated, state, newvals, optimizer, loss

    def train(self, dialogues):
        """
        Train the neural net policy model

        :param dialogues: dialogue experience
        :return: nothing
        """

        # If called by accident
        if not self.is_training:
            return

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss =\
            self.policy_net

        states = []
        actions = []

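        # Collect (encoded state, one-hot expert action) training pairs;
        # turns whose action cannot be encoded (act_enc == -1) are skipped.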
        for dialogue in dialogues:
            for index, turn in enumerate(dialogue):
                act_enc = \
                    self.encode_action(turn['action'],
                                       self.agent_role == 'system')
                if act_enc > -1:
                    states.append(self.encode_state(turn['state']))
                    action = np.zeros(self.NActions)
                    action[act_enc] = 1
                    actions.append(action)

        # Train policy
        self.sess.run(pl_optimizer,
                      feed_dict={
                          pl_state: states,
                          pl_newvals: actions
                      })

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: int - a unique state encoding
        """

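        # Binary feature vector: terminal and offer flags, then either (user
        # role) which goal constraints / requests still need to be, and have
        # already been, communicated, or (system role) which slots are filled
        # and which slot is currently being requested.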
        temp = [int(state.is_terminal_state)]

        temp.append(1 if state.system_made_offer else 0)

        # If the agent plays the role of the user it needs access to its own
        # goal
        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests need
            # to be communicated and which of them actually have been.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)
            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        return temp

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        if not actions:
            print('WARNING: Supervised DialoguePolicy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        slot = None
        if action.params and action.params[0].slot:
            slot = action.params[0].slot

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.system_requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           self.system_requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           len(self.system_requestable_slots) + \
                           self.requestable_slots.index(slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           self.requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           len(self.requestable_slots) + \
                           self.requestable_slots.index(slot)

        # Default fall-back action
        print('Supervised ({0}) policy action encoder warning: Selecting '
              'default action (unable to encode: {1})!'.format(
                  self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) +\
                    len(self.requestable_slots):
                index = action_enc - \
                        len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        print('DEBUG: {0} learning rate is: {1}'.format(
            self.agent_role, self.policy_alpha))

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'Models/Policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        if self.sess is not None and self.is_training:
            save_path = self.tf_saver.save(self.sess, pol_path)
            print('Supervised DialoguePolicy model saved at: %s' % save_path)

    def load(self, path):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return:
        """

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'Models/Policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        if os.path.isfile(pol_path + '.meta'):
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()

            self.tf_saver = \
                tf.train.Saver(
                    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope=self.tf_scope))

            self.tf_saver.restore(self.sess, pol_path)

            print('Supervised DialoguePolicy model loaded from {0}.'.format(
                pol_path))

        else:
            print('WARNING! Supervised DialoguePolicy cannot load policy '
                  'model from {0}!'.format(pol_path))
Exemplo n.º 12
0
class WoLFPHCPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 alpha=0.25,
                 gamma=0.95,
                 epsilon=0.25,
                 alpha_decay=0.9995,
                 epsilon_decay=0.995):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay = alpha_decay
        self.epsilon_decay = epsilon_decay

        self.IS_GREEDY_POLICY = False

        # TODO: Put these as arguments in the config
        self.d_win = 0.0025
        self.d_lose = 0.01

        self.is_training = False

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        else:
            raise ValueError('WoLF PHC DialoguePolicy: Unacceptable database '
                             'type %s ' % database)

        self.Q = {}
        self.pi = {}
        self.mean_pi = {}
        self.state_counter = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = \
                HandcraftedPolicy.HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = dict(
                zip(['ontology', 'database'], [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        # DSTC2 action sets (used for the CamRest-style domain)
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Does not include 'inform' and 'request', which are modelled together
        # with their arguments
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        # Does not include 'inform' and 'request', which are modelled together
        # with their arguments
        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if self.dstc2_acts_sys:
            if self.agent_role == 'system':
                # self.NActions = 5
                # self.NOtherActions = 4
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

            elif self.agent_role == 'user':
                # self.NActions = 4
                # self.NOtherActions = 5
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) +\
                    len(self.system_requestable_slots)

                self.NOtherActions = len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])

            elif self.agent_role == 'user':
                self.NActions = \
                    4 + 2 * len(self.ontology.ontology['requestable'])
                self.NOtherActions = \
                    5 + len(self.ontology.ontology['system_requestable']) + \
                    len(self.ontology.ontology['requestable'])

        self.statistics = {'supervised_turns': 0, 'total_turns': 0}
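        # Worked example of the action-space sizes above (hypothetical
        # CamRest-like slot counts, not values read from the ontology): with
        # the 13 dstc2_acts_sys intents, 8 requestable slots and 3
        # system-requestable slots, a system agent gets
        # NActions = 13 + 8 + 3 = 24 and, since dstc2_acts_usr also has 13
        # intents, NOtherActions = 13 + 8 + 3 = 24 as well.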

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if 'learning_rate' in kwargs:
                self.alpha = float(kwargs['learning_rate'])

            if 'learning_decay_rate' in kwargs:
                self.alpha_decay = float(kwargs['learning_decay_rate'])

            if 'exploration_rate' in kwargs:
                self.epsilon = float(kwargs['exploration_rate'])

            if 'exploration_decay_rate' in kwargs:
                self.epsilon_decay = float(kwargs['exploration_decay_rate'])

            if 'gamma' in kwargs:
                self.gamma = float(kwargs['gamma'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    # Pass the goal as a dict, matching the interface used
                    # in restart() below
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    print('WARNING ! No goal provided for WoLF-PHC policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for WoLF-PHC policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        state_enc = self.encode_state(state)
        self.statistics['total_turns'] += 1

        if state_enc not in self.pi or \
                (self.is_training and random.random() < self.epsilon):
            if not self.is_training:
                if not self.pi:
                    print(f'\nWARNING! WoLF-PHC pi is empty '
                          f'({self.agent_role}). Did you load the correct '
                          f'file?\n')
                else:
                    print(f'\nWARNING! WoLF-PHC state not found in policy '
                          f'pi ({self.agent_role}).\n')

            if random.random() < 0.5:
                print('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))
                self.statistics['supervised_turns'] += 1

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()
            else:
                print('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

        if self.IS_GREEDY_POLICY:
            # Get greedy action
            max_pi = max(self.pi[state_enc][:-1])  # Do not consider 'UNK'
            maxima = \
                [i for i, j in enumerate(self.pi[state_enc]) if j == max_pi]

            # Break ties randomly
            if maxima:
                sys_acts = \
                    self.decode_action(random.choice(maxima),
                                       self.agent_role == 'system')
            else:
                print('--- {0}: Warning! No maximum value identified for '
                      'policy. Selecting random action.'.format(
                          self.agent_role))

                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')
        else:
            # Sample next action
            sys_acts = \
                self.decode_action(
                    random.choices(range(len(self.pi[state_enc])),
                                   self.pi[state_enc])[0],
                    self.agent_role == 'system')

        return sys_acts
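        # Behaviour summary: for unseen states, or with probability epsilon
        # while training, next_action() falls back 50/50 to the warm-up
        # expert (HandcraftedPolicy / AgendaBasedUS) or a uniform random
        # action; otherwise it either takes the argmax of pi[state_enc]
        # (IS_GREEDY_POLICY) or samples an action from pi[state_enc].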

    def encode_state(self, state):
        """
        Encodes the dialogue state into an index used to address the Q matrix.

        :param state: the state to encode
        :return: int - a unique state encoding
        """

        temp = [int(state.is_terminal_state)]

        temp.append(1 if state.system_made_offer else 0)

        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # need to be communicated and which of them have actually been
            # communicated.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints and \
                                state.user_goal.constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)

            else:
                temp += \
                    [0] * 2*(len(self.informable_slots)-1 +
                             len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        # Encode state
        state_enc = 0
        for t in temp:
            state_enc = (state_enc << 1) | t

        return state_enc
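        # Bit-packing example: a (hypothetical) feature vector
        # temp = [1, 0, 1, 1] is folded left-to-right into
        # 0b1011 == 11, so every distinct combination of boolean features
        # maps to a unique integer key for Q / pi / mean_pi.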

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        if not actions:
            print('WARNING: WoLF-PHC DialoguePolicy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        if (self.agent_role == 'system') == system:
            print('WoLF-PHC ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode: {1})!'.format(
                      self.agent_role, action))
        else:
            print('WoLF-PHC ({0}) policy action encoder warning: Selecting '
                  'default action (unable to encode other agent action: {1})!'.
                  format(self.agent_role, action))

        return -1
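        # Offset example (hypothetical slot lists): a system act
        # inform(food=...) is encoded as
        # len(dstc2_acts_sys) + len(system_requestable_slots) +
        # requestable_slots.index('food') == 13 + 3 + index;
        # decode_action() below inverts this by subtracting the same offsets.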

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = \
                    action_enc - len(self.dstc2_acts_sys) - \
                    len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        print('WoLF-PHC DialoguePolicy ({0}) policy action decoder warning: '
              'Selecting repeat() action (index: {1})!'.format(
                  self.agent_role, action_enc))
        return [DialogueAct('repeat', [])]
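        # Decoding sketch: decode_action(0, system=True) returns
        # [DialogueAct('offer', [])] (the first dstc2_acts_sys intent), while
        # an index just past len(dstc2_acts_sys) becomes a 'request' act whose
        # slot comes from system_requestable_slots; anything out of range
        # falls back to the 'repeat' act returned above.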

    def train(self, dialogues):
        """
        Train the model using WoLF-PHC.

        :param dialogues: a list of dialogues, each being a list of dialogue
                          turns (state, action, reward triplets).
        :return:
        """

        if not self.is_training:
            return

        for dialogue in dialogues:
            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            for turn in dialogue:
                state_enc = self.encode_state(turn['state'])
                new_state_enc = self.encode_state(turn['new_state'])
                action_enc = \
                    self.encode_action(
                        turn['action'],
                        self.agent_role == 'system')

                # Skip unrecognised actions
                if action_enc < 0 or turn['action'][0].intent == 'bye':
                    continue

                if state_enc not in self.Q:
                    self.Q[state_enc] = [0] * self.NActions

                if new_state_enc not in self.Q:
                    self.Q[new_state_enc] = [0] * self.NActions

                if state_enc not in self.pi:
                    self.pi[state_enc] = \
                        [float(1/self.NActions)] * self.NActions

                if state_enc not in self.mean_pi:
                    self.mean_pi[state_enc] = \
                        [float(1/self.NActions)] * self.NActions

                if state_enc not in self.state_counter:
                    self.state_counter[state_enc] = 1
                else:
                    self.state_counter[state_enc] += 1

                # Update Q
                self.Q[state_enc][action_enc] = \
                    ((1 - self.alpha) * self.Q[state_enc][action_enc]) + \
                    self.alpha * (
                            turn['reward'] +
                            (self.gamma * np.max(self.Q[new_state_enc])))

                # Update mean policy estimate
                for a in range(self.NActions):
                    self.mean_pi[state_enc][a] = \
                        self.mean_pi[state_enc][a] + \
                        ((1.0 / self.state_counter[state_enc]) *
                         (self.pi[state_enc][a] - self.mean_pi[state_enc][a]))

                # Determine delta
                sum_policy = 0.0
                sum_mean_policy = 0.0

                for a in range(self.NActions):
                    sum_policy = sum_policy + (self.pi[state_enc][a] *
                                               self.Q[state_enc][a])
                    sum_mean_policy = \
                        sum_mean_policy + \
                        (self.mean_pi[state_enc][a] * self.Q[state_enc][a])

                if sum_policy > sum_mean_policy:
                    delta = self.d_win
                else:
                    delta = self.d_lose

                # Update policy estimate
                max_Q_idx = np.argmax(self.Q[state_enc])

                d_plus = delta
                d_minus = ((-1.0) * d_plus) / (self.NActions - 1.0)

                for a in range(self.NActions):
                    if a == max_Q_idx:
                        self.pi[state_enc][a] = \
                            min(1.0, self.pi[state_enc][a] + d_plus)
                    else:
                        self.pi[state_enc][a] = \
                            max(0.0, self.pi[state_enc][a] + d_minus)

                # Constrain pi to a legal probability distribution
                sum_pi = sum(self.pi[state_enc])
                for a in range(self.NActions):
                    self.pi[state_enc][a] /= sum_pi

        # Decay learning rate after each episode
        if self.alpha > 0.001:
            self.alpha *= self.alpha_decay

        # Decay exploration rate after each episode
        if self.epsilon > 0.25:
            self.epsilon *= self.epsilon_decay

        print('[alpha: {0}, epsilon: {1}]'.format(self.alpha, self.epsilon))
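        # Numeric sketch of the WoLF step above (hypothetical values): with
        # NActions == 4 and Q[state_enc] favouring action 2, a "winning"
        # state (sum_policy > sum_mean_policy) uses delta = d_win = 0.0025,
        # so pi[state_enc][2] grows by up to 0.0025 while each other action
        # shrinks by 0.0025 / 3; a "losing" state uses the larger
        # d_lose = 0.01, i.e. the policy adapts faster when it is doing worse
        # than its running average.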

    def save(self, path=None):
        """
        Saves the policy model to the path provided

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/wolf_phc_policy.pkl'
            print('No policy file name provided. Using default: {0}'.format(
                path))

        obj = {
            'Q': self.Q,
            'pi': self.pi,
            'mean_pi': self.mean_pi,
            'state_counter': self.state_counter,
            'a': self.alpha,
            'e': self.epsilon,
            'g': self.gamma
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

        if self.statistics['total_turns'] > 0:
            print('DEBUG > {0} WoLF PHC DialoguePolicy supervision ratio: {1}'.
                  format(
                      self.agent_role,
                      float(self.statistics['supervised_turns'] /
                            self.statistics['total_turns'])))

        print(f'DEBUG > {self.agent_role} WoLF PHC DialoguePolicy state space '
              f'size: {len(self.pi)}')

    def load(self, path=None):
        """
        Load the policy model from the path provided

        :param path: path to load the model from
        :return:
        """

        if not path:
            print('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'Q' in obj:
                        self.Q = obj['Q']
                    if 'pi' in obj:
                        self.pi = obj['pi']
                    if 'mean_pi' in obj:
                        self.mean_pi = obj['mean_pi']
                    if 'state_counter' in obj:
                        self.state_counter = obj['state_counter']
                    if 'a' in obj:
                        self.alpha = obj['a']
                    if 'e' in obj:
                        self.epsilon = obj['e']
                    if 'g' in obj:
                        self.gamma = obj['g']

                    print('WoLF-PHC DialoguePolicy loaded from {0}.'.format(
                        path))

            else:
                print('Warning! WoLF-PHC DialoguePolicy file %s not found' %
                      path)
        else:
            print('Warning! Unacceptable value for WoLF-PHC policy file name:'
                  ' %s ' % path)