Code Example #1
    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    # Pass a dict (not a set literal) so that the
                    # simulator can look up the goal by key
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    print('WARNING! No goal provided for Supervised policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.policy_alpha = kwargs['learning_rate']

        if self.sess is None:
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()
            self.sess.run(tf.global_variables_initializer())

            self.tf_saver = \
                tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.tf_scope))
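A note on the `warmup_simulator.initialize` call above: `{'goal': kwargs['goal']}` is a dict literal, whereas `{kwargs['goal']}` would be a one-element set with no key to look up. The companion `restart` method (see Code Example #3) hands the simulator a whole args dict, so the keyed form is the one it can consume. A minimal standalone illustration:

goal = 'hypothetical_goal'       # placeholder value, for illustration only
as_set = {goal}                  # set literal: membership only, no keys
as_dict = {'goal': goal}         # dict literal: what initialize() can read
assert isinstance(as_set, set) and 'goal' in as_dict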
Code Example #2
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.2,
                 epsilon=0.95,
                 gamma=0.95,
                 alpha_decay=0.995,
                 epsilon_decay=0.9995,
                 epsilon_min=0.05):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        super(ReinforcePolicy, self).__init__()

        self.logger = logging.getLogger(__name__)

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.IS_GREEDY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.weights = None
        self.sess = None

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay_rate = alpha_decay
        self.exploration_decay_rate = epsilon_decay
        self.epsilon_min = epsilon_min

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # The DSTC2 act lists below are only populated for the CamRest
        # domain, so default them to None to keep attribute access safe
        self.dstc2_acts = None
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                self.logger.warning(
                    'Warning! Domain has not been defined. Using '
                    'Slot-Filling Dialogue State')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            self.logger.info(
                'Reinforce DialoguePolicy {0} automatically determined '
                'number of state features: {1}'.format(self.agent_role,
                                                       self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        self.logger.info(
            'Reinforce {0} DialoguePolicy Number of Actions: {1}'.format(
                self.agent_role, self.NActions))
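To make the CamRest action-count arithmetic above concrete, here is a minimal sketch for a system agent. The slot lists are hypothetical stand-ins; the real counts come from the domain's ontology:

# Hypothetical slot lists standing in for the CamRest ontology
dstc2_acts_sys = ['offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                  'reqmore', 'welcomemsg', 'expl-conf', 'select',
                  'repeat', 'confirm-domain', 'confirm']    # 13 intents
requestable_slots = ['addr', 'phone', 'postcode']           # hypothetical
system_requestable_slots = ['area', 'food', 'pricerange']   # hypothetical

# One action per bare intent, one 'request' per system-requestable slot
# and one 'inform' per requestable slot:
n_actions = (len(dstc2_acts_sys)
             + len(requestable_slots)
             + len(system_requestable_slots))
print(n_actions)  # 13 + 3 + 3 = 19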
Code Example #3
class ReinforcePolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.2,
                 epsilon=0.95,
                 gamma=0.95,
                 alpha_decay=0.995,
                 epsilon_decay=0.9995,
                 epsilon_min=0.05):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param alpha: the learning rate
        :param gamma: the discount rate
        :param epsilon: the exploration rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        """

        super(ReinforcePolicy, self).__init__()

        self.logger = logging.getLogger(__name__)

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.IS_GREEDY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.weights = None
        self.sess = None

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay_rate = alpha_decay
        self.exploration_decay_rate = epsilon_decay
        self.epsilon_min = epsilon_min

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # The DSTC2 act lists below are only populated for the CamRest
        # domain, so default them to None to keep attribute access safe
        self.dstc2_acts = None
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                self.logger.warning(
                    'Warning! Domain has not been defined. Using '
                    'Slot-Filling Dialogue State')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            self.logger.info(
                'Reinforce DialoguePolicy {0} automatically determined '
                'number of state features: {1}'.format(self.agent_role,
                                                       self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        self.logger.info(
            'Reinforce {0} DialoguePolicy Number of Actions: {1}'.format(
                self.agent_role, self.NActions))

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    # Pass a dict (not a set literal) so that the
                    # simulator can look up the goal by key
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    self.logger.warning(
                        'WARNING! No goal provided for Reinforce policy '
                        'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.alpha = kwargs['learning_rate']

        if 'learning_decay_rate' in kwargs:
            self.alpha_decay_rate = kwargs['learning_decay_rate']

        if 'discount_factor' in kwargs:
            self.gamma = kwargs['discount_factor']

        if 'exploration_rate' in kwargs:
            self.epsilon = kwargs['exploration_rate']

        if 'exploration_decay_rate' in kwargs:
            self.exploration_decay_rate = kwargs['exploration_decay_rate']

        if self.weights is None:
            self.weights = np.random.rand(self.NStateFeatures, self.NActions)

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return: nothing
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                self.logger.warning(
                    'WARNING! No goal provided for Reinforce policy user '
                    'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training and random.random() < self.epsilon:
            if random.random() < 0.75:
                self.logger.debug('--- {0}: Selecting warmup action.'.format(
                    self.agent_role))

                if self.agent_role == 'system':
                    return self.warmup_policy.next_action(state)

                else:
                    self.warmup_simulator.receive_input(
                        state.user_acts, state.user_goal)
                    return self.warmup_simulator.respond()

            else:
                self.logger.debug('--- {0}: Selecting random action.'.format(
                    self.agent_role))
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == "system")

        # Probabilistic policy: Sample from action wrt probabilities
        probs = self.calculate_policy(self.encode_state(state))

        if any(np.isnan(probs)):
            self.logger.warning(
                'WARNING! NAN detected in action probabilities! Selecting '
                'random action.')
            return self.decode_action(random.choice(range(0, self.NActions)),
                                      self.agent_role == "system")

        if self.IS_GREEDY:
            # Get greedy action
            max_pi = max(probs)
            maxima = [i for i, j in enumerate(probs) if j == max_pi]

            # Break ties randomly
            if maxima:
                sys_acts = \
                    self.decode_action(
                        random.choice(maxima), self.agent_role == 'system')
            else:
                self.logger.warning(
                    f'--- {self.agent_role}: Warning! No maximum value '
                    f'identified for policy. Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')
        else:
            # Sample among the top 3 actions, weighted by their probabilities
            top_3 = np.argsort(-probs)[0:3]
            sys_acts = \
                self.decode_action(
                    random.choices(
                        top_3, weights=probs[top_3])[0],
                    self.agent_role == 'system')

        return sys_acts

    @staticmethod
    def softmax(x):
        """
        Calculates the softmax of x

        :param x: a number
        :return: the softmax of the number
        """
        e_x = np.exp(x - np.max(x))
        out = e_x / e_x.sum()
        return out

    @staticmethod
    def softmax_gradient(x):
        """
        Calculates the gradient of the softmax

        :param x: a number
        :return: the gradient of the softmax
        """
        x = np.asarray(x)
        x_reshaped = x.reshape(-1, 1)
        return np.diagflat(x_reshaped) - np.dot(x_reshaped, x_reshaped.T)

    def calculate_policy(self, state):
        """
        Calculates the probabilities for each action from the given state

        :param state: the current dialogue state
        :return: probabilities of actions
        """
        dot_prod = np.dot(state, self.weights)
        # Subtract the max before exponentiating for numerical stability;
        # this leaves the resulting probabilities unchanged
        exp_dot_prod = np.exp(dot_prod - np.max(dot_prod))
        return exp_dot_prod / np.sum(exp_dot_prod)

    def train(self, dialogues):
        """
        Train the policy network

        :param dialogues: dialogue experience
        :return: nothing
        """
        # If called by accident
        if not self.is_training:
            return

        for dialogue in dialogues:
            discount = self.gamma

            if len(dialogue) > 1:
                dialogue[-2]['reward'] = dialogue[-1]['reward']

            rewards = [t['reward'] for t in dialogue]
            norm_rewards = \
                (rewards - np.mean(rewards)) / (np.std(rewards) + 0.000001)

            for (t, turn) in enumerate(dialogue):
                act_enc = self.encode_action(turn['action'],
                                             self.agent_role == 'system')
                if act_enc < 0:
                    continue

                state_enc = self.encode_state(turn['state'])

                if len(state_enc) != self.NStateFeatures:
                    raise ValueError(f'Reinforce DialoguePolicy '
                                     f'{self.agent_role} mismatch in state '
                                     f'dimensions: State Features: '
                                     f'{self.NStateFeatures} != State '
                                     f'Encoding Length: {len(state_enc)}')

                # Calculate the gradients

                # Call policy again to retrieve the probability of the
                # action taken
                probabilities = self.calculate_policy(state_enc)

                softmax_deriv = self.softmax_gradient(probabilities)[act_enc]
                log_policy_grad = softmax_deriv / probabilities[act_enc]
                gradient = \
                    np.asarray(
                        state_enc)[None, :].transpose().dot(
                        log_policy_grad[None, :])
                gradient = np.clip(gradient, -1.0, 1.0)

                # Train policy
                self.weights += \
                    self.alpha * gradient * norm_rewards[t] * discount
                self.weights = np.clip(self.weights, -1, 1)

                discount *= self.gamma

        if self.alpha > 0.01:
            self.alpha *= self.alpha_decay_rate

        self.decay_epsilon()

        self.logger.info(
            f'REINFORCE train, alpha: {self.alpha}, epsilon: {self.epsilon}')

    def decay_epsilon(self):
        """
        Decays epsilon (exploration rate) by epsilon decay.

         Decays epsilon (exploration rate) by epsilon decay.
         If epsilon is already less or equal compared to epsilon_min,
         the call of this method has no effect.

        :return:
        """
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.exploration_decay_rate

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: int - a unique state encoding
        """

        temp = [int(state.is_terminal_state), int(state.system_made_offer)]

        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # still need to be communicated and which of them have
            # actually been communicated.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:

                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)

            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            for value in state.slots_filled.values():
                # This contains the requested slot
                temp.append(1 if value else 0)

            for r in self.requestable_slots:
                temp.append(1 if r in state.requested_slots else 0)

        return temp

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be encoding another agent's
        action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        # TODO: Handle multiple actions
        # TODO: Action encoding in a principled way
        if not actions:
            self.logger.warning(
                'WARNING: Reinforce DialoguePolicy action encoding called '
                'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_sys) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_sys) + \
                       len(self.system_requestable_slots) + \
                       self.requestable_slots.index(action.params[0].slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if action.intent == 'request':
                return len(self.dstc2_acts_usr) + \
                       self.requestable_slots.index(action.params[0].slot)

            if action.intent == 'inform':
                return len(self.dstc2_acts_usr) + \
                       len(self.requestable_slots) + \
                       self.system_requestable_slots.index(
                           action.params[0].slot)

        # Default fall-back action
        self.logger.warning(
            'Reinforce ({0}) policy action encoder warning: Selecting '
            'default action (unable to encode: {1})!'.format(
                self.agent_role, action))
        return -1

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that this does not have to
        match the agent's role, as the agent may be decoding another agent's
        action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) + \
                    len(self.requestable_slots):
                index = action_enc - len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]

        # Default fall-back action
        self.logger.warning(
            'Reinforce DialoguePolicy ({0}) policy action decoder warning: '
            'Selecting default action (index: {1})!'.format(
                self.agent_role, action_enc))
        return [DialogueAct('bye', [])]

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        if not path:
            path = 'Models/Policies/reinforce.pkl'
            self.logger.warning(
                'No policy file name provided. Using default: {0}'.format(
                    path))

        obj = {
            'weights': self.weights,
            'alpha': self.alpha,
            'alpha_decay_rate': self.alpha_decay_rate,
            'epsilon': self.epsilon,
            'exploration_decay_rate': self.exploration_decay_rate,
            'epsilon_min': self.epsilon_min
        }

        with open(path, 'wb') as file:
            pickle.dump(obj, file, pickle.HIGHEST_PROTOCOL)

    def load(self, path=None):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return:
        """

        if not path:
            self.logger.warning('No policy loaded.')
            return

        if isinstance(path, str):
            if os.path.isfile(path):
                with open(path, 'rb') as file:
                    obj = pickle.load(file)

                    if 'weights' in obj:
                        self.weights = obj['weights']

                    if 'alpha' in obj:
                        self.alpha = obj['alpha']

                    if 'alpha_decay_rate' in obj:
                        self.alpha_decay_rate = obj['alpha_decay_rate']

                    if 'epsilon' in obj:
                        self.epsilon = obj['epsilon']

                    if 'exploration_decay_rate' in obj:
                        self.exploration_decay_rate = \
                            obj['exploration_decay_rate']

                    if 'epsilon_min' in obj:
                        self.epsilon_min = obj['epsilon_min']

                    self.logger.info(
                        'Reinforce DialoguePolicy loaded from {0}.'.format(
                            path))

            else:
                self.logger.warning(
                    'Warning! Reinforce DialoguePolicy file %s not found' %
                    path)
        else:
            self.logger.warning(
                'Warning! Unacceptable value for Reinforce DialoguePolicy '
                'file name: %s ' % path)
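The REINFORCE update in `train` above relies on the fact that the Jacobian of the softmax at a probability vector p is diag(p) - p p^T, and that dividing its act_enc-th row by p[act_enc] gives the gradient of log pi(a|s) with respect to the logits. A standalone numerical check of that identity, assuming only numpy (the helper names mirror the static methods above):

import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def softmax_gradient(p):
    p = np.asarray(p).reshape(-1, 1)
    return np.diagflat(p) - np.dot(p, p.T)

z = np.array([1.0, 2.0, 0.5])      # hypothetical logits
p = softmax(z)
J = softmax_gradient(p)

# Column j of the finite-difference estimate is d softmax(z) / d z_j;
# it should match the analytic Jacobian entry-for-entry
eps = 1e-6
fd = np.zeros_like(J)
for j in range(len(z)):
    dz = np.zeros_like(z)
    dz[j] = eps
    fd[:, j] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)

print(np.allclose(J, fd, atol=1e-6))  # True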
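`train` also standardises each dialogue's rewards before applying the update, and discounts by gamma per turn. A small sketch of just that bookkeeping, with made-up rewards:

import numpy as np

gamma = 0.95
rewards = np.asarray([0.0, -0.05, -0.05, 1.0])  # hypothetical per-turn rewards

# Standardise as in train(); the small constant guards against a zero std
norm_rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-6)

discount = gamma
for t, r in enumerate(norm_rewards):
    # each turn's weight update is scaled by alpha * norm_reward * discount
    print(t, round(float(r), 3), round(discount, 4))
    discount *= gamma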
Code Example #4
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for Dialogue State Tracking and Dialogue Policy.

        :param args: the configuration file parsed into a dictionary
        """

        if 'settings' not in args:
            raise AttributeError(
                'DialogueManager: Please provide settings (config)!')
        if 'ontology' not in args:
            raise AttributeError('DialogueManager: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('DialogueManager: Please provide database!')
        if 'domain' not in args:
            raise AttributeError('DialogueManager: Please provide domain!')

        settings = args['settings']
        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        agent_id = 0
        if 'agent_id' in args:
            agent_id = int(args['agent_id'])

        agent_role = 'system'
        if 'agent_role' in args:
            agent_role = args['agent_role']

        self.settings = settings

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.policy = None
        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            if database[-3:] == '.db':
                self.database = SQLDataBase(database)
            elif database[-5:] == '.json':
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)

        else:
            raise ValueError('Unacceptable database type %s ' % database)

        if 'policy' in args and args['policy']:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'Domain is not specified in the DIALOGUE section of '
                    'the config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            if args['policy']['type'] == 'handcrafted':
                self.policy = HandcraftedPolicy(self.ontology)

            elif args['policy']['type'] == 'q_learning':
                alpha = None
                if 'learning_rate' in args['policy']:
                    alpha = float(args['policy']['learning_rate'])

                gamma = None
                if 'discount_factor' in args['policy']:
                    gamma = float(args['policy']['discount_factor'])

                epsilon = None
                if 'exploration_rate' in args['policy']:
                    epsilon = float(args['policy']['exploration_rate'])

                alpha_decay = None
                if 'learning_decay_rate' in args['policy']:
                    alpha_decay = float(args['policy']['learning_decay_rate'])

                epsilon_decay = None
                if 'exploration_decay_rate' in args['policy']:
                    epsilon_decay = \
                        float(args['policy']['exploration_decay_rate'])

                self.policy = \
                    QPolicy(self.ontology,
                            self.database,
                            self.agent_id,
                            self.agent_role,
                            self.domain,
                            alpha=alpha,
                            epsilon=epsilon,
                            gamma=gamma,
                            alpha_decay=alpha_decay,
                            epsilon_decay=epsilon_decay)

            elif args['policy']['type'] == 'minimax_q':
                alpha = 0.25
                gamma = 0.95
                epsilon = 0.25
                alpha_decay = 0.9995
                epsilon_decay = 0.995

                if 'learning_rate' in args['policy']:
                    alpha = float(args['policy']['learning_rate'])

                if 'discount_factor' in args['policy']:
                    gamma = float(args['policy']['discount_factor'])

                if 'exploration_rate' in args['policy']:
                    epsilon = float(args['policy']['exploration_rate'])

                if 'learning_decay_rate' in args['policy']:
                    alpha_decay = float(args['policy']['learning_decay_rate'])

                if 'exploration_decay_rate' in args['policy']:
                    epsilon_decay = \
                        float(args['policy']['exploration_decay_rate'])

                self.policy = \
                    MinimaxQPolicy(
                        self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        alpha=alpha,
                        epsilon=epsilon,
                        gamma=gamma,
                        alpha_decay=alpha_decay,
                        epsilon_decay=epsilon_decay)

            elif args['policy']['type'] == 'wolf_phc':
                alpha = 0.25
                gamma = 0.95
                epsilon = 0.25
                alpha_decay = 0.9995
                epsilon_decay = 0.995

                if 'learning_rate' in args['policy']:
                    alpha = float(args['policy']['learning_rate'])

                if 'discount_factor' in args['policy']:
                    gamma = float(args['policy']['discount_factor'])

                if 'exploration_rate' in args['policy']:
                    epsilon = float(args['policy']['exploration_rate'])

                if 'learning_decay_rate' in args['policy']:
                    alpha_decay = float(args['policy']['learning_decay_rate'])

                if 'exploration_decay_rate' in args['policy']:
                    epsilon_decay = \
                        float(args['policy']['exploration_decay_rate'])

                self.policy = \
                    WoLFPHCPolicy(
                        self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        alpha=alpha,
                        epsilon=epsilon,
                        gamma=gamma,
                        alpha_decay=alpha_decay,
                        epsilon_decay=epsilon_decay)

            elif args['policy']['type'] == 'reinforce':
                alpha = None
                if 'learning_rate' in args['policy']:
                    alpha = float(args['policy']['learning_rate'])

                gamma = None
                if 'discount_factor' in args['policy']:
                    gamma = float(args['policy']['discount_factor'])

                epsilon = None
                if 'exploration_rate' in args['policy']:
                    epsilon = float(args['policy']['exploration_rate'])

                alpha_decay = None
                if 'learning_decay_rate' in args['policy']:
                    alpha_decay = float(args['policy']['learning_decay_rate'])

                epsilon_decay = None
                if 'exploration_decay_rate' in args['policy']:
                    epsilon_decay = \
                        float(args['policy']['exploration_decay_rate'])

                self.policy = \
                    ReinforcePolicy(
                        self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        self.domain,
                        alpha=alpha,
                        epsilon=epsilon,
                        gamma=gamma,
                        alpha_decay=alpha_decay,
                        epsilon_decay=epsilon_decay)

            elif args['policy']['type'] == 'calculated':
                self.policy = \
                    CalculatedPolicy(
                        self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        self.domain)

            elif args['policy']['type'] == 'supervised':
                self.policy = \
                    SupervisedPolicy(
                        self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        self.domain)

            elif args['policy']['type'] == 'ludwig':
                if args['policy']['policy_path']:
                    print('DialogueManager: Instantiate your ludwig-based '
                          'policy here')
                else:
                    raise ValueError(
                        'Cannot find policy_path in the config for dialogue '
                        'policy.')
            else:
                raise ValueError(
                    'DialogueManager: Unsupported policy type {0}!'.format(
                        args['policy']['type']))

            if 'train' in args['policy']:
                self.TRAIN_POLICY = bool(args['policy']['train'])

            if 'policy_path' in args['policy']:
                self.policy_path = args['policy']['policy_path']

        # DST Settings
        if 'DST' in args and args['DST']['dst']:
            if args['DST']['dst'] == 'CamRest':
                if args['DST']['policy']['model_path'] and \
                        args['DST']['policy']['metadata_path']:
                    self.DSTracker = \
                        CamRestLudwigDST(
                            {'model_path': args[
                                'DST']['policy']['model_path']})
                else:
                    raise ValueError(
                        'Cannot find model_path or metadata_path in the '
                        'config for dialogue state tracker.')

        # Default to dummy DST
        if not self.DSTracker:
            dst_args = dict(
                zip(['ontology', 'database', 'domain'],
                    [self.ontology, self.database, domain]))
            self.DSTracker = DummyStateTracker(dst_args)

        self.load('')
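For reference, here is a hypothetical args dictionary covering the keys this constructor reads; the paths and numeric values are placeholders, not part of the original source:

args = {
    'settings': {'DIALOGUE': {'domain': 'CamRest'}},
    'ontology': 'path/to/ontology.json',   # str path or an Ontology instance
    'database': 'path/to/database.db',     # '.db' -> SQLDataBase,
                                           # '.json' -> JSONDataBase
    'domain': 'CamRest',
    'agent_id': 0,
    'agent_role': 'system',
    'policy': {
        'type': 'reinforce',
        'learning_rate': 0.25,
        'discount_factor': 0.95,
        'exploration_rate': 0.25,
        'learning_decay_rate': 0.9995,
        'exploration_decay_rate': 0.995,
        'train': True,
        'policy_path': 'path/to/policy.pkl'
    }
    # An optional 'DST' section selects a dialogue state tracker;
    # without it the DummyStateTracker is used.
}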
Code Example #5
    def init_policy(self, args):
        if not args or 'policy' not in args or not args['policy']:
            # Early return
            return

        # Collect all (potential) parameters.
        # Read the float-valued parameters from the policy section of the
        # config; this also translates between the parameter names used in
        # config files and those used in the code.
        float_params_key_map = {
            'learning_rate': 'alpha',
            'discount_factor': 'gamma',
            'exploration_rate': 'epsilon',
            'learning_decay_rate': 'alpha_decay',
            'exploration_decay_rate': 'epsilon_decay',
            'min_exploration_rate': 'epsilon_min'
        }
        policy_params = {
            float_params_key_map[k]: float(v)
            for k, v in args['policy'].items() if k in float_params_key_map
        }

        # Read the remaining parameters, whose names are identical in
        # config files and code.
        policy_params.update({
            k: v
            for k, v in args['policy'].items() if k not in float_params_key_map
        })

        # initialize the policy (depending on the configured policy type)
        if args['policy']['type'] == 'handcrafted':
            self.policy = HandcraftedPolicy(self.ontology)

        elif args['policy']['type'] == 'q_learning':

            self.policy = \
                QPolicy(self.ontology,
                        self.database,
                        self.agent_id,
                        self.agent_role,
                        self.domain,
                        print_level=self.print_level,
                        **policy_params)

        elif args['policy']['type'] == 'pytorch_reinforce':

            self.policy = \
                PyTorchReinforcePolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    self.domain,
                    print_level=self.print_level,
                    **policy_params)

        elif args['policy']['type'] == 'pytorch_a2c':

            self.policy = \
                PyTorchA2CPolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    self.domain,
                    print_level=self.print_level,
                    **policy_params)

        elif args['policy']['type'] == 'minimax_q':

            self.policy = \
                MinimaxQPolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    **policy_params)

        elif args['policy']['type'] == 'wolf_phc':

            self.policy = \
                WoLFPHCPolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    **policy_params)

        elif args['policy']['type'] == 'reinforce':

            self.policy = \
                ReinforcePolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    self.domain,
                    **policy_params)

        elif args['policy']['type'] == 'calculated':
            self.policy = \
                CalculatedPolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    self.domain)

        elif args['policy']['type'] == 'supervised':
            self.policy = \
                SupervisedPolicy(
                    self.ontology,
                    self.database,
                    self.agent_id,
                    self.agent_role,
                    self.domain)

        elif args['policy']['type'] == 'ludwig':
            if args['policy']['policy_path']:
                print('DialogueManager: Instantiate your ludwig-based '
                      'policy here')
            else:
                raise ValueError(
                    'Cannot find policy_path in the config for dialogue '
                    'policy.')
        else:
            raise ValueError(
                'DialogueManager: Unsupported policy type {0}!'.format(
                    args['policy']['type']))

        if 'train' in args['policy']:
            self.TRAIN_POLICY = bool(args['policy']['train'])

        if 'policy_path' in args['policy']:
            self.policy_path = args['policy']['policy_path']
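The parameter-collection logic at the top of `init_policy` is easy to exercise in isolation. A minimal sketch with a hypothetical policy config section (config values arrive as strings and are converted to floats only for the mapped keys):

float_params_key_map = {
    'learning_rate': 'alpha',
    'discount_factor': 'gamma',
    'exploration_rate': 'epsilon',
    'learning_decay_rate': 'alpha_decay',
    'exploration_decay_rate': 'epsilon_decay',
    'min_exploration_rate': 'epsilon_min'
}

policy_cfg = {'type': 'reinforce',
              'learning_rate': '0.25',
              'discount_factor': '0.95'}   # hypothetical config section

policy_params = {float_params_key_map[k]: float(v)
                 for k, v in policy_cfg.items() if k in float_params_key_map}
policy_params.update({k: v for k, v in policy_cfg.items()
                      if k not in float_params_key_map})

print(policy_params)
# {'alpha': 0.25, 'gamma': 0.95, 'type': 'reinforce'}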
Code Example #6
class SupervisedPolicy(DialoguePolicy.DialoguePolicy):
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology
        :param database: the domain's database
        :param agent_id: the agent's id
        :param agent_role: the agent's role
        :param domain: the dialogue's domain
        """
        super(SupervisedPolicy, self).__init__()

        self.agent_id = agent_id
        self.agent_role = agent_role

        # True for greedy, False for stochastic
        self.IS_GREEDY_POLICY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        self.policy_net = None
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)
        self.sess = None

        # The system and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'] +
                     ['this', 'signature'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:
                DState = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

            else:
                print('Warning! Domain has not been defined. Using '
                      'Slot-Filling Dialogue State')
                DState = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            DState.initialize()
            self.NStateFeatures = len(self.encode_state(DState))
            print('Supervised DialoguePolicy automatically determined number '
                  'of state features: {0}'.format(self.NStateFeatures))

        if domain == 'CamRest':
            self.NActions = len(self.dstc2_acts) + len(self.requestable_slots)

            if self.agent_role == 'system':
                self.NActions += len(self.system_requestable_slots)
            else:
                self.NActions += len(self.requestable_slots)
        else:
            self.NActions = 5

        self.policy_alpha = 0.05

        self.tf_saver = None

    def initialize(self, **kwargs):
        """
        Initialize internal structures at the beginning of each dialogue

        :return: Nothing
        """

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = \
                dict(
                    zip(['ontology', 'database'],
                        [self.ontology, self.database]))
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        if 'is_training' in kwargs:
            self.is_training = bool(kwargs['is_training'])

            if self.agent_role == 'user' and self.warmup_simulator:
                if 'goal' in kwargs:
                    # Pass a dict (not a set literal) so that the
                    # simulator can look up the goal by key
                    self.warmup_simulator.initialize({'goal': kwargs['goal']})
                else:
                    print('WARNING! No goal provided for Supervised policy '
                          'user simulator @ initialize')
                    self.warmup_simulator.initialize({})

        if 'policy_path' in kwargs:
            self.policy_path = kwargs['policy_path']

        if 'learning_rate' in kwargs:
            self.policy_alpha = kwargs['learning_rate']

        if self.sess is None:
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()
            self.sess.run(tf.global_variables_initializer())

            self.tf_saver = \
                tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES, scope=self.tf_scope))

    def restart(self, args):
        """
        Re-initialize relevant parameters / variables at the beginning of each
        dialogue.

        :return:
        """

        if self.agent_role == 'user' and self.warmup_simulator:
            if 'goal' in args:
                self.warmup_simulator.initialize(args)
            else:
                print('WARNING! No goal provided for Supervised policy user '
                      'simulator @ restart')
                self.warmup_simulator.initialize({})

    def next_action(self, state):
        """
        Consults the policy to produce the agent's response

        :param state: the current dialogue state
        :return: a list of dialogue acts, representing the agent's response
        """

        if self.is_training:
            # This is a Supervised DialoguePolicy: while training, return the
            # expert's action so that the network can learn to imitate it.

            if self.agent_role == 'system':
                return self.warmup_policy.next_action(state)
            else:
                self.warmup_simulator.receive_input(state.user_acts,
                                                    state.user_goal)
                return self.warmup_simulator.generate_output()

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss = \
            self.policy_net

        obs_vector = np.expand_dims(self.encode_state(state), axis=0)

        probs = self.sess.run(pl_calculated, feed_dict={pl_state: obs_vector})

        if self.IS_GREEDY_POLICY:
            # Greedy policy: Return action with maximum value from the given
            # state
            sys_acts = \
                self.decode_action(
                    np.argmax(probs), self.agent_role == 'system')

        else:
            # Stochastic policy: sample an action according to the network's
            # output probabilities
            if any(np.isnan(probs[0])):
                print('WARNING! Supervised DialoguePolicy: NaN detected in '
                      'action probabilities! Selecting random action.')
                return self.decode_action(
                    random.choice(range(0, self.NActions)),
                    self.agent_role == 'system')

            # Shift the weights so that they are all non-negative (softmax
            # outputs already are, so this is purely defensive)
            min_p = min(probs[0])

            if min_p < 0:
                positive_weights = np.array(
                    [p + abs(min_p) for p in probs[0]])
            else:
                positive_weights = probs[0]

            # Normalize the weights so that they sum to 1
            positive_weights /= sum(positive_weights)

            sys_acts = \
                self.decode_action(
                    random.choices(
                        [a for a in range(self.NActions)],
                        weights=positive_weights)[0],
                    self.agent_role == 'system')

        return sys_acts
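
    # Worked example of the shift-and-normalize step above (illustrative
    # values): probs[0] = [-0.2, 0.1, 0.3] -> shifted = [0.0, 0.3, 0.5]
    # -> normalized = [0.0, 0.375, 0.625], which random.choices then uses as
    # sampling weights.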

    def feed_forward_net_init(self):
        """
        Initialize the feed forward network.

        :return: some useful variables
        """
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        with tf.variable_scope(self.tf_scope):
            state = tf.placeholder("float", [None, self.NStateFeatures])
            newvals = tf.placeholder("float", [None, self.NActions])

            w1 = \
                tf.get_variable("w1",
                                [self.NStateFeatures, self.NStateFeatures])
            b1 = tf.get_variable("b1", [self.NStateFeatures])
            h1 = tf.nn.sigmoid(tf.matmul(state, w1) + b1)

            w2 = \
                tf.get_variable("w2",
                                [self.NStateFeatures, self.NStateFeatures])
            b2 = tf.get_variable("b2", [self.NStateFeatures])
            h2 = tf.nn.sigmoid(tf.matmul(h1, w2) + b2)

            w3 = tf.get_variable("w3", [self.NStateFeatures, self.NActions])
            b3 = tf.get_variable("b3", [self.NActions])

            calculated = tf.nn.softmax(tf.matmul(h2, w3) + b3)

            diffs = calculated - newvals
            loss = tf.nn.l2_loss(diffs)
            optimizer = \
                tf.train.AdamOptimizer(self.policy_alpha).minimize(loss)

            return calculated, state, newvals, optimizer, loss
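
    # Architecture sketch: two sigmoid hidden layers of width NStateFeatures,
    # a softmax output over NActions, and an L2 loss between the predicted
    # distribution and the one-hot expert action; in effect, behavioural
    # cloning trained with Adam.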

    def train(self, dialogues):
        """
        Train the neural net policy model

        :param dialogues: dialogue experience
        :return: nothing
        """

        # If called by accident
        if not self.is_training:
            return

        pl_calculated, pl_state, pl_newvals, pl_optimizer, pl_loss =\
            self.policy_net

        states = []
        actions = []

        for dialogue in dialogues:
            for turn in dialogue:
                act_enc = \
                    self.encode_action(turn['action'],
                                       self.agent_role == 'system')
                if act_enc > -1:
                    states.append(self.encode_state(turn['state']))
                    action = np.zeros(self.NActions)
                    action[act_enc] = 1
                    actions.append(action)

        # Train policy
        self.sess.run(pl_optimizer,
                      feed_dict={
                          pl_state: states,
                          pl_newvals: actions
                      })
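
        # Assumed structure of `dialogues` (inferred from the loop above):
        # a list of dialogues, each a list of turn dicts with at least
        # 'state' and 'action' keys, e.g.
        #
        #     dialogues = [[{'state': s0, 'action': a0}, ...], ...]
        #
        # Note that each call performs a single Adam step over the whole
        # batch.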

    def encode_state(self, state):
        """
        Encodes the dialogue state into a vector.

        :param state: the state to encode
        :return: list - the state encoded as a feature vector
        """

        temp = [int(state.is_terminal_state)]

        temp.append(1 if state.system_made_offer else 0)

        # If the agent plays the role of the user it needs access to its own
        # goal
        if self.agent_role == 'user':
            # The user agent needs to know which constraints and requests
            # still need to be communicated and which of them have actually
            # been communicated already.
            if state.user_goal:
                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.constraints:
                            temp.append(1)
                        else:
                            temp.append(0)

                for c in self.informable_slots:
                    if c != 'name':
                        if c in state.user_goal.actual_constraints and \
                                state.user_goal.actual_constraints[c].value:
                            temp.append(1)
                        else:
                            temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.requests:
                        temp.append(1)
                    else:
                        temp.append(0)

                for r in self.requestable_slots:
                    if r in state.user_goal.actual_requests and \
                            state.user_goal.actual_requests[r].value:
                        temp.append(1)
                    else:
                        temp.append(0)
            else:
                temp += [0] * 2 * (len(self.informable_slots) - 1 +
                                   len(self.requestable_slots))

        if self.agent_role == 'system':
            # One bit per slot: has the slot been filled?
            for value in state.slots_filled.values():
                temp.append(1 if value else 0)

            # One-hot encoding of the currently requested slot
            for r in self.requestable_slots:
                temp.append(1 if r == state.requested_slot else 0)

        return temp
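
    # Resulting feature vector, per the branches above:
    #   [terminal, offer_made]
    #   + (user role)   2 * (len(informable_slots) - 1) constraint bits
    #                   + 2 * len(requestable_slots) request bits
    #   + (system role) len(slots_filled) filled-slot bits
    #                   + len(requestable_slots) requested-slot one-hot bits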

    def encode_action(self, actions, system=True):
        """
        Encode the action, given the role. Note that the role does not have
        to match the agent's own role, as the agent may be encoding another
        agent's action (e.g. a system encoding the previous user act).

        :param actions: actions to be encoded
        :param system: whether the role whose action we are encoding is a
                       'system'
        :return: the encoded action
        """

        if not actions:
            print('WARNING: Supervised DialoguePolicy action encoding called '
                  'with empty actions list (returning -1).')
            return -1

        action = actions[0]

        slot = None
        if action.params and action.params[0].slot:
            slot = action.params[0].slot

        if system:
            if self.dstc2_acts_sys and action.intent in self.dstc2_acts_sys:
                return self.dstc2_acts_sys.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.system_requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           self.system_requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_sys) + \
                           len(self.system_requestable_slots) + \
                           self.requestable_slots.index(slot)
        else:
            if self.dstc2_acts_usr and action.intent in self.dstc2_acts_usr:
                return self.dstc2_acts_usr.index(action.intent)

            if slot:
                if action.intent == 'request' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           self.requestable_slots.index(slot)

                if action.intent == 'inform' and \
                        slot in self.requestable_slots:
                    return len(self.dstc2_acts_usr) + \
                           len(self.requestable_slots) + \
                           self.requestable_slots.index(slot)

        # Fall-back: the action could not be encoded
        print('Supervised ({0}) policy action encoder warning: unable to '
              'encode action {1}; returning -1.'.format(
                  self.agent_role, action))
        return -1
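
    # Worked, hypothetical example for a system agent: encoding inform(food)
    # yields
    #     len(dstc2_acts_sys) + len(system_requestable_slots)
    #         + requestable_slots.index('food')
    # assuming 'food' is one of the domain's requestable slots.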

    def decode_action(self, action_enc, system=True):
        """
        Decode the action, given the role. Note that the role does not have
        to match the agent's own role, as the agent may be decoding another
        agent's action (e.g. a system decoding the previous user act).

        :param action_enc: action encoding to be decoded
        :param system: whether the role whose action we are decoding is a
                       'system'
        :return: the decoded action
        """

        if system:
            if action_enc < len(self.dstc2_acts_sys):
                return [DialogueAct(self.dstc2_acts_sys[action_enc], [])]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.system_requestable_slots[
                                action_enc - len(self.dstc2_acts_sys)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_sys) + \
                    len(self.system_requestable_slots) +\
                    len(self.requestable_slots):
                index = action_enc - \
                        len(self.dstc2_acts_sys) - \
                        len(self.system_requestable_slots)
                return [
                    DialogueAct('inform', [
                        DialogueActItem(self.requestable_slots[index],
                                        Operator.EQ, '')
                    ])
                ]

        else:
            if action_enc < len(self.dstc2_acts_usr):
                return [DialogueAct(self.dstc2_acts_usr[action_enc], [])]

            if action_enc < len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots):
                return [
                    DialogueAct('request', [
                        DialogueActItem(
                            self.requestable_slots[action_enc -
                                                   len(self.dstc2_acts_usr)],
                            Operator.EQ, '')
                    ])
                ]

            if action_enc < len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots):
                return [
                    DialogueAct('inform', [
                        DialogueActItem(
                            self.requestable_slots[
                                action_enc - len(self.dstc2_acts_usr) -
                                len(self.requestable_slots)], Operator.EQ, '')
                    ])
                ]
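
    # decode_action is the inverse of encode_action. Note that an encoding
    # outside the valid range falls through both branches and implicitly
    # returns None, so callers should only pass values in [0, NActions).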

    def save(self, path=None):
        """
        Saves the policy model to the provided path

        :param path: path to save the model to
        :return:
        """

        # Don't save if not training
        if not self.is_training:
            return

        print('DEBUG: {0} learning rate is: {1}'.format(
            self.agent_role, self.policy_alpha))

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'Models/Policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        if self.sess is not None:
            save_path = self.tf_saver.save(self.sess, pol_path)
            print('Supervised DialoguePolicy model saved at: %s' % save_path)

    def load(self, path):
        """
        Load the policy model from the provided path

        :param path: path to load the model from
        :return: Nothing
        """

        pol_path = path

        if not pol_path:
            pol_path = self.policy_path

        if not pol_path:
            pol_path = 'Models/Policies/supervised_policy_' + \
                       self.agent_role + '_' + str(self.agent_id)

        if os.path.isfile(pol_path + '.meta'):
            self.policy_net = self.feed_forward_net_init()
            self.sess = tf.InteractiveSession()

            self.tf_saver = \
                tf.train.Saver(
                    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                               scope=self.tf_scope))

            self.tf_saver.restore(self.sess, pol_path)

            print('Supervised DialoguePolicy model loaded from {0}.'.format(
                pol_path))

        else:
            print('WARNING! Supervised DialoguePolicy cannot load policy '
                  'model from {0}!'.format(pol_path))
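

# A minimal end-to-end usage sketch. This is hypothetical: the class name
# `SupervisedPolicy` and the `ontology`, `database`, `dialogues` and
# `current_state` objects are assumptions, not part of this listing.
#
#     policy = SupervisedPolicy(ontology, database, agent_role='system',
#                               domain='CamRest')
#     policy.initialize(is_training=True, learning_rate=0.05)
#     policy.train(dialogues)               # one Adam step over the batch
#     policy.save()                         # checkpoint under Models/Policies/
#     policy.initialize(is_training=False)
#     response = policy.next_action(current_state)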