Example #1
0
    def __init__(self, args):
        """
        Initializes the internal structures of the DummyStateTracker. Loads the
        DataBase and Ontology, retrieves the DataBase table name, and creates
        the Dialogue State.

        :param args: dictionary that must contain the keys 'ontology' (an
                     Ontology or a path to one), 'database' (a DataBase or a
                     path ending in '.db' or '.json') and 'domain'
        :raises AttributeError: if one of the required keys is missing
        :raises ValueError: if the ontology or database has an unsupported
                            type or file extension
        """

        super(DummyStateTracker, self).__init__()

        for required in ('ontology', 'database', 'domain'):
            if required not in args:
                raise AttributeError(
                    'DummyStateTracker: Please provide %s!' % required)

        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % (ontology,))

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str):
            # Choose the backend from the file extension (the suffix, i.e. the
            # END of the path).
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError(
                    'Unacceptable database type %s ' % (database,))
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))

        # Get Table name
        self.db_table_name = self.database.get_table_name()

        # Placeholder row count; any division by it before the real count is
        # filled in (presumably by initialize()) will fail.
        self.DB_ITEMS = 0

        self.domain = domain
        if domain not in ['CamRest', 'SlotFilling']:
            print('Warning! Domain has not been defined. Using Slot-Filling '
                  'Dialogue State')
        # Known and unknown domains both get the same slot-filling state.
        self.DState = SlotFillingDialogueState(
            {'slots': self.ontology.ontology['system_requestable']})
Example #2
0
    def __init__(self, args):
        """
        Set up the CamRest Ludwig tracker.

        Construction is first delegated to the parent initializer with
        ``args``; afterwards the tracker owns a fresh slot-filling dialogue
        state built from an empty list.

        :param args: a dictionary containing the path to the Ludwig model
        """
        parent = super(CamRestLudwigDST, self)
        parent.__init__(args)

        initial_state = SlotFillingDialogueState([])
        self.DState = initial_state
Example #3
0
    def _build_text_field(self, domain: Domain):
        """
        Build a text ``Field`` whose vocabulary covers the domain values and
        the serialized dialogue state.

        The vocabulary is, in order: every element of every list-valued entry
        of ``domain._asdict()``, the digit strings '0'..'9', and the tokens of
        an empty ``SlotFillingDialogueState`` serialized by ``state_to_json``
        (with bare double-quote tokens dropped).

        :param domain: the domain description; must provide ``_asdict()``
        :return: a batch-first ``Field`` using the regex tokenizer, with its
                 vocabulary already built
        """
        empty_state_json = state_to_json(SlotFillingDialogueState([]))

        domain_tokens = []
        for entry in domain._asdict().values():
            if isinstance(entry, list):
                domain_tokens.extend(entry)

        def regex_tokenizer(text,
                            pattern=r"(?u)(?:\b\w\w+\b|\S)") -> List[str]:
            # Words of two or more word characters, or any single non-space
            # character.
            return [match.group() for match in re.finditer(pattern, text)]

        state_tokens = list(filter(lambda tok: tok != '"',
                                   regex_tokenizer(empty_state_json)))
        digit_tokens = [str(digit) for digit in range(10)]

        text_field = Field(batch_first=True, tokenize=regex_tokenizer)
        text_field.build_vocab([domain_tokens + digit_tokens + state_tokens])
        return text_field
Example #4
0
    def next_action(self, state: SlotFillingDialogueState):
        """
        Choose the system acts for the given dialogue state.

        The agent is put in eval mode and moved to DEVICE. While training,
        with probability ``self.epsilon`` the acts come from
        ``self.warmup_policy``; otherwise they are decoded from the agent's
        own step. Either way the agent's value estimate for the state is
        stored on ``state.value``.

        :param state: the current dialogue state
        :return: the chosen system acts (a one-element list when decoded from
                 the agent, otherwise whatever the warm-up policy returns)
        """
        self.agent.eval()
        self.agent.to(DEVICE)

        state_enc = self.encode_state(state).to(DEVICE)
        with torch.no_grad():
            explore = self.is_training and random.random() < self.epsilon
            if explore:
                # Act with the warm-up policy but still record the agent's
                # own value estimate.
                sys_acts = self.warmup_policy.next_action(state)
                value = self.agent.calc_value(state_enc).cpu().item()
            else:
                step_out = self.agent.step(state_enc)
                value = step_out.v_values.cpu().item()
                sys_acts = [self.decode_action(step_out)]

        state.value = value
        return sys_acts
Example #5
0
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None,
                 alpha=0.2,
                 epsilon=0.95,
                 gamma=0.95,
                 alpha_decay=0.995,
                 epsilon_decay=0.9995,
                 epsilon_min=0.05):
        """
        Initialize parameters and internal structures

        :param ontology: the domain's ontology (an Ontology.Ontology)
        :param database: the domain's database (a DataBase.DataBase)
        :param agent_id: the agent's id
        :param agent_role: the agent's role, either 'system' or 'user'
        :param domain: the dialogue's domain; when falsy, hard-coded CamRest
                       dimensions (56 state features) and acts are used
        :param alpha: the learning rate
        :param epsilon: the exploration rate
        :param gamma: the discount rate
        :param alpha_decay: the learning rate discount rate
        :param epsilon_decay: the exploration rate discount rate
        :param epsilon_min: presumably the floor for the exploration rate as
                            it decays (stored as self.epsilon_min)
        :raises ValueError: if the ontology, database or agent role is not
                            acceptable
        """

        super(ReinforcePolicy, self).__init__()

        self.logger = logging.getLogger(__name__)

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.IS_GREEDY = False

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Unacceptable ontology type %s ' % (ontology,))

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'database type %s ' % (database,))

        # NActions / NOtherActions below are only defined for these two
        # roles; fail here with a clear message rather than with an
        # AttributeError when self.NActions is logged at the end.
        if self.agent_role not in ('system', 'user'):
            raise ValueError('Reinforce DialoguePolicy: Unacceptable '
                             'agent role %s ' % (self.agent_role,))

        self.policy_path = None

        self.weights = None
        self.sess = None

        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.alpha_decay_rate = alpha_decay
        self.exploration_decay_rate = epsilon_decay
        self.epsilon_min = epsilon_min

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert policy here
            self.warmup_policy = HandcraftedPolicy(self.ontology)

        elif self.agent_role == 'user':
            usim_args = {'ontology': self.ontology,
                         'database': self.database}
            # Put your user expert policy here
            self.warmup_simulator = AgendaBasedUS(usim_args)

        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        # Act inventories. Which of these gets filled depends on `domain`
        # below; defaulting them to None makes the attributes exist on every
        # code path instead of only some.
        self.dstc2_acts = None
        self.dstc2_acts_sys = None
        self.dstc2_acts_usr = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'inform', 'offer', 'request', 'canthelp', 'affirm', 'negate',
                'deny', 'ack', 'thankyou', 'bye', 'reqmore', 'hello',
                'welcomemsg', 'expl-conf', 'select', 'repeat', 'reqalts',
                'confirm-domain', 'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                d_state = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

            else:
                self.logger.warning(
                    'Warning! Domain has not been defined. Using '
                    'Slot-Filling Dialogue State')
                d_state = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            d_state.initialize()
            self.NStateFeatures = len(self.encode_state(d_state))

            self.logger.info(
                'Reinforce DialoguePolicy {0} automatically determined '
                'number of state features: {1}'.format(self.agent_role,
                                                       self.NStateFeatures))

        if domain == 'CamRest' and self.dstc2_acts_sys:
            if self.agent_role == 'system':
                self.NActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                # NOTE(review): this differs from the user-role NActions
                # below (2 * requestable vs. requestable +
                # system_requestable); confirm which count is intended.
                self.NOtherActions = \
                    len(self.dstc2_acts_usr) + \
                    2 * len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    len(self.dstc2_acts_usr) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

                self.NOtherActions = \
                    len(self.dstc2_acts_sys) + \
                    len(self.requestable_slots) + \
                    len(self.system_requestable_slots)

        else:
            if self.agent_role == 'system':
                self.NActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    2 + len(self.requestable_slots) +\
                    len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = \
                    2 + len(self.requestable_slots) + \
                    len(self.requestable_slots)

                self.NOtherActions = \
                    3 + len(self.system_requestable_slots) + \
                    len(self.requestable_slots)

        self.logger.info(
            'Reinforce {0} DialoguePolicy Number of Actions: {1}'.format(
                self.agent_role, self.NActions))
Example #6
0
class CamRestLudwigDST(LudwigDST):
    """
    CamRest dialogue state tracker that delegates state prediction to a
    Ludwig model.

    For every input dialogue act, a one-row feature table (previous state plus
    the act) is passed to ``self.model.predict``. The predictions overwrite
    the area/food/pricerange slots, the first requested slot and the terminal
    flag of a ``SlotFillingDialogueState``.
    """

    # Informable slots fed back to the model as prev_ds_<slot> and
    # overwritten from the ds_<slot>_predictions column.
    _PREDICTED_SLOTS = ('area', 'food', 'pricerange')

    def __init__(self, args):
        """
        Initialize the Ludwig model and the dialogue state.

        :param args: a dictionary containing the path to the Ludwig model
        """
        super(CamRestLudwigDST, self).__init__(args)

        self.DState = SlotFillingDialogueState([])

    def initialize(self):
        """
        Initialize the dialogue state

        :return: nothing
        """

        super(CamRestLudwigDST, self).initialize()
        self.DState.initialize()

    @staticmethod
    def _prediction_or_empty(result, column):
        """
        Return the first value of ``result[column]``, mapping the model's
        'na' placeholder to an empty string.
        """
        prediction = result[column][0]
        return prediction if prediction != 'na' else ''

    def update_state(self, dacts):
        """
        Updates the current dialogue state given the input dialogue acts. This
        function will process the input, query the ludwig model, and retrieve
        the updated dialogue state.

        :param dacts: the input dialogue acts (usually NLU output)
        :return: the current dialogue state
        """

        self.DState.user_acts = deepcopy(dacts)

        # Default features. Values are one-element lists (the scalar
        # 'transcription' is broadcast), so pandas builds a single-row frame.
        input_data = {'transcription': '',
                      'prev_ds_area': ['na'],
                      'prev_ds_food': ['na'],
                      'prev_ds_pricerange': ['na'],
                      'prev_ds_requested': ['na'],
                      'prev_ds_user_terminating': [False],
                      'act': ['na'],
                      'inform_area': ['na'],
                      'inform_food': ['na'],
                      'inform_pricerange': ['na'],
                      'request_area': [False],
                      'request_food': [False],
                      'request_pricerange': [False],
                      'request_addr': [False],
                      'request_name': [False],
                      'request_phone': [False],
                      'request_postcode': [False]}

        # Feed the previous state back wherever it holds a value.
        for slot in self._PREDICTED_SLOTS + ('requested',):
            if self.DState.slots_filled[slot]:
                input_data['prev_ds_' + slot] = \
                    [self.DState.slots_filled[slot]]

        if self.DState.is_terminal_state:
            input_data['prev_ds_user_terminating'] = \
                [self.DState.is_terminal_state]

        for dact in dacts:
            # Keep the same one-element-list shape as the default 'act'.
            input_data['act'] = [dact.intent]

            if dact.intent in ['inform', 'request']:
                for item in dact.params:
                    input_data[dact.intent + '_' + item.slot] = [item.value]

            # NOTE(review): input_data is not reset between acts, so
            # inform_*/request_* features set for an earlier act in this turn
            # are still present for later ones, and prev_ds_* keep the values
            # from before the turn. Confirm this is intended.

            # Warning: Make sure the same tokenizer that was used to train
            # the model is used during prediction
            result = self.model.predict(pd.DataFrame(data=input_data))

            for slot in self._PREDICTED_SLOTS:
                self.DState.slots_filled[slot] = self._prediction_or_empty(
                    result, 'ds_' + slot + '_predictions')

            # TODO: only the first requested slot is used.
            # NOTE(review): assumes requested_slots already holds at least
            # one element; on an empty list this assignment raises IndexError.
            self.DState.requested_slots[0] = self._prediction_or_empty(
                result, 'ds_requested_predictions')

            self.DState.is_terminal_state = \
                result['ds_user_terminating_predictions'][0]

        self.DState.turn += 1

        return self.DState

    def update_state_db(self, db_result):
        """
        Updates the current database results in the dialogue state.

        :param db_result: the database query results (a sequence; its first
                          element becomes the item in focus)
        :return: the current dialogue state
        """

        if db_result:
            # NOTE(review): this stores the number of matches, not a ratio
            # (no division by the database size happens here); confirm
            # against consumers of db_matches_ratio.
            self.DState.db_matches_ratio = len(db_result)
            self.DState.item_in_focus = db_result[0]

        return self.DState

    def update_state_sysact(self, sys_acts):
        """
        Updates the state, given the last system act. This is
        useful as we may want to update parts of the state given NLU output
        and then update again once the system produces a response.

        :param sys_acts: the last system acts
        :return: nothing
        """

        for sys_act in sys_acts:
            if sys_act.intent == 'offer':
                self.DState.system_made_offer = True

            # Reset the request if the system asks for more information,
            # assuming that any previously offered item
            # is now invalid.
            elif sys_act.intent == 'request':
                self.DState.system_made_offer = False

    def get_state(self):
        """
        Returns the current dialogue state

        :return: the current dialogue state
        """
        return self.DState

    def train(self, dialogue_episodes):
        """
        Not implemented.

        We can use Ludwig's API to train the model online (i.e. for a single
        dialogue).

        :param dialogue_episodes: dialogue experience (currently ignored)
        :return: nothing
        """
        pass

    def save(self, model_path=None):
        """
        Saves the Ludwig model.

        :param model_path: path to save the model to
        :return: nothing
        """
        super(CamRestLudwigDST, self).save(model_path)

    def load(self, model_path):
        """
        Loads the Ludwig model from the given path.

        :param model_path: path to the model
        :return: nothing
        """

        super(CamRestLudwigDST, self).load(model_path)
Example #7
0
class DummyStateTracker(DialogueStateTracker):
    """
    Rule-based dialogue state tracker over a slot-filling dialogue state.

    Intents, slots and values from the input dialogue acts are copied into
    the state (filled slots, requested slots, affirm/deny flags, terminal
    flag). Database results and system acts update the item in focus and,
    when a user goal is attached, its requests and constraints.
    """

    def __init__(self, args):
        """
        Initializes the internal structures of the DummyStateTracker. Loads the
        DataBase and Ontology, retrieves the DataBase table name, and creates
        the Dialogue State.

        :param args: dictionary that must contain the keys 'ontology' (an
                     Ontology or a path to one), 'database' (a DataBase or a
                     path ending in '.db' or '.json') and 'domain'
        :raises AttributeError: if one of the required keys is missing
        :raises ValueError: if the ontology or database has an unsupported
                            type or file extension
        """

        super(DummyStateTracker, self).__init__()

        for required in ('ontology', 'database', 'domain'):
            if required not in args:
                raise AttributeError(
                    'DummyStateTracker: Please provide %s!' % required)

        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % (ontology,))

        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str):
            # Choose the backend from the file extension (the suffix, i.e. the
            # END of the path).
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError(
                    'Unacceptable database type %s ' % (database,))
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))

        # Get Table name
        self.db_table_name = self.database.get_table_name()

        # Placeholder row count until initialize() fills in the real one;
        # update_state_db divides by it, so that would fail before then.
        self.DB_ITEMS = 0

        self.domain = domain
        if domain not in ['CamRest', 'SlotFilling']:
            print('Warning! Domain has not been defined. Using Slot-Filling '
                  'Dialogue State')
        # Known and unknown domains both get the same slot-filling state.
        self.DState = SlotFillingDialogueState(
            {'slots': self.ontology.ontology['system_requestable']})

    def initialize(self, args=None):
        """
        Initializes the database results and dialogue state.

        Counts the rows of the database table (the denominator of
        db_matches_ratio in update_state_db) and resets the dialogue state.

        :param args: forwarded to the dialogue state's initialize()
        :return: nothing
        """

        self.DB_ITEMS = self._count_db_items()

        if self.DB_ITEMS <= 0:
            print('Warning! DST could not get number of DB items.')
            self.DB_ITEMS = 110  # Default for CamRestaurants

        self.DState.initialize(args)
        # No constraints have been expressed yet
        self.DState.db_matches_ratio = 1.0

        self.DState.turn = 0

    def _count_db_items(self):
        """
        Return the number of rows in the database table, or 0 when the
        database exposes no SQL connection (so the caller's default applies).
        """
        connection = getattr(self.database, 'SQL_connection', None)
        if connection is None:
            return 0

        cursor = connection.cursor()
        try:
            # The table name comes from self.database.get_table_name();
            # identifiers cannot be bound as SQL parameters.
            cursor.execute('SELECT COUNT(*) FROM ' + self.db_table_name)
            return cursor.fetchone()[0]
        finally:
            cursor.close()

    def update_state(self, dacts):
        """
        Update the dialogue state given the input dialogue acts. This function
        basically tracks which intents, slots, and values have been mentioned
        and updates the dialogue state accordingly.

        :param dacts: a list of dialogue acts (usually the output of NLU)
        :return: the updated dialogue state
        """

        # TODO: slots that are not already keys of slots_filled are not
        # added; their values only reach the user goal (if one is set).
        self.DState.user_acts = deepcopy(dacts)

        # Reset past request
        self.DState.requested_slots = []
        self.DState.user_affirmed_last_sys_acts = False
        self.DState.user_denied_last_sys_acts = False

        for dact in dacts:
            if dact.intent == 'affirm':
                # The user affirms an explicit confirmation
                self.DState.user_affirmed_last_sys_acts = True
                # Mark every slot of the last system expl-conf acts as
                # confirmed
                if self.DState.last_sys_acts:
                    for act in self.DState.last_sys_acts:
                        if act.intent == 'expl-conf' and act.params:
                            for p in act.params:
                                self.DState.slots_confirmed[p.slot] = True

            if dact.intent == 'deny':
                # The user denies an explicit confirmation
                self.DState.user_denied_last_sys_acts = True

            if dact.intent in ['inform', 'offer']:
                # system_made_offer is not cleared when new information
                # arrives; only an 'offer' act changes it here.
                if dact.intent == 'offer':
                    self.DState.system_made_offer = True

                for dact_item in dact.params:
                    if dact_item.slot in self.DState.slots_filled:
                        self.DState.slots_filled[dact_item.slot] = \
                            dact_item.value

                    elif self.DState.user_goal:
                        if dact_item.slot in \
                                self.DState.user_goal.actual_requests:
                            self.DState.user_goal.actual_requests[
                                dact_item.slot].value = dact_item.value

                        # Only update requests that have been asked for
                        if dact_item.slot in self.DState.user_goal.requests:
                            self.DState.user_goal.requests[
                                dact_item.slot].value = dact_item.value

            elif dact.intent == 'request':
                for dact_item in dact.params:
                    if dact_item.slot == 'slot' and dact_item.value:
                        # Case where we have request(slot = slot_name)
                        self.DState.requested_slots.append(dact_item.value)
                    else:
                        # Case where we have: request(slot_name)
                        self.DState.requested_slots.append(dact_item.slot)

            elif dact.intent == 'bye':
                self.DState.is_terminal_state = True

        # Increment turn
        self.DState.turn += 1

        return self.DState

    def update_state_db(self,
                        db_result=None,
                        sys_req_slot_entropies=None,
                        sys_acts=None):
        """
        This is a special function that is mostly designed for the multi-agent
        setup. If the state belongs to a 'system' agent, then this function
        will update the current database results. If the state belongs to a
        'user' agent, then this function will update the 'item in focus' fields
        of the dialogue state, given the last system action.

        :param db_result: the database query results (a sequence; ['empty'] at
                          index 0 means no item is in focus)
        :param sys_req_slot_entropies: calculated entropies for requestable
                                       slots
        :param sys_acts: the system's acts
        :return: the updated dialogue state
        :raises ValueError: if both db_result and sys_acts are given
        """

        if db_result and sys_acts:
            raise ValueError('Dialogue State Tracker: Cannot update state as '
                             'both system and user (i.e. please use only one '
                             'argument as appropriate).')

        # This should be called if the agent is a system
        if db_result:
            # NOTE(review): DB_ITEMS is 0 until initialize() runs, so calling
            # this first raises ZeroDivisionError. An ['empty'] result also
            # counts as one match here; confirm that is intended.
            self.DState.db_matches_ratio = len(db_result) / self.DB_ITEMS

            if db_result[0] == 'empty':
                self.DState.item_in_focus = []

            else:
                self.DState.item_in_focus = db_result[0]

            if sys_req_slot_entropies:
                self.DState.system_requestable_slot_entropies = \
                    deepcopy(sys_req_slot_entropies)

            self.DState.db_result = db_result

        # This should be called if the agent is a user
        elif sys_acts:
            # Create dictionary if it doesn't exist or reset it if a new offer
            # has been made
            if not self.DState.item_in_focus or \
                    'offer' in [a.intent for a in sys_acts]:
                self.DState.item_in_focus = \
                    dict.fromkeys(self.ontology.ontology['requestable'])

            for sys_act in sys_acts:
                if sys_act.intent in ['inform', 'offer']:
                    for item in sys_act.params:
                        self.DState.item_in_focus[item.slot] = item.value

                        if self.DState.user_goal:
                            if item.slot in \
                                    self.DState.user_goal.actual_requests:
                                self.DState.user_goal.actual_requests[
                                    item.slot].value = item.value

                            # Only update requests that have been asked for
                            if item.slot in self.DState.user_goal.requests:
                                self.DState.user_goal.requests[
                                    item.slot].value = item.value

        return self.DState

    def update_state_sysact(self, sys_acts):
        """
        Updates the last system act and the goal, given that act. This is
        useful as we may want to update parts of the state given NLU output
        and then update again once the system produces a response.

        :param sys_acts: the last system acts
        :return: nothing
        """

        if sys_acts:
            self.DState.last_sys_acts = sys_acts

            for sys_act in sys_acts:

                if sys_act.intent == 'offer':
                    self.DState.system_made_offer = True

                # Keep track of actual requests made. These are used in reward
                # and success calculation for systems. The
                # reasoning is that it does not make sense to penalise a system
                # for an unanswered request that was
                # never actually made by the user.
                # If the current agent is a system then these will be
                # disregarded.
                if sys_act.intent == 'request' and sys_act.params and \
                        self.DState.user_goal:
                    self.DState.user_goal.actual_requests[
                        sys_act.params[0].slot] = sys_act.params[0]

                # Similarly, keep track of actual constraints made.
                if sys_act.intent == 'inform' and sys_act.params and \
                        self.DState.user_goal:
                    self.DState.user_goal.actual_constraints[
                        sys_act.params[0].slot] = sys_act.params[0]

    def update_goal(self, goal):
        """
        Updates the agent's goal

        :param goal: a Goal
        :return: nothing
        """

        # TODO: Do a deep copy?
        self.DState.user_goal = goal

    def train(self, data):
        """
        Nothing to do here.

        :param data: ignored
        :return: nothing
        """
        pass

    def get_state(self):
        """
        Returns the current dialogue state.

        :return: the current dialogue state
        """
        return self.DState

    def save(self, path=None):
        """
        Nothing to do here.

        :param path: ignored
        :return: nothing
        """
        pass

    def load(self, path):
        """
        Nothing to do here.

        :param path: ignored
        :return: nothing
        """
        pass
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None):
        """
        Load the ontology and database and initialize some internal structures

        :param ontology: the domain ontology (an Ontology.Ontology)
        :param database: the domain database (a DataBase.DataBase)
        :param agent_id: the agent's id
        :param agent_role: the agent's role; NActions is only set for
                           'system' or 'user'
        :param domain: the dialogue's domain; when falsy, hard-coded CamRest
                       dimensions (56 state features) and acts are used
        :raises ValueError: if the ontology or database has the wrong type
        """
        super(CalculatedPolicy, self).__init__()

        self.agent_id = agent_id
        self.agent_role = agent_role

        # True for greedy, False for stochastic
        self.IS_GREEDY_POLICY = True

        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Calculated DialoguePolicy: Unacceptable '
                             'ontology type %s ' % (ontology,))

        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Calculated DialoguePolicy: Unacceptable '
                             'database type %s ' % (database,))

        self.policy_path = None

        self.policy = None

        # Extract lists of slots that are frequently used
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'] + ['this'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        self.dstc2_acts = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # CamRest acts. Does not include inform and request that are modelled
        # together with their arguments. Used both as the no-domain default
        # and for the CamRest domain.
        camrest_acts = [
            'offer', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
            'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
            'expl-conf', 'select', 'repeat', 'reqalts', 'confirm-domain',
            'confirm'
        ]

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions.
            self.dstc2_acts = camrest_acts
        else:
            if domain in ['CamRest', 'SFH', 'SlotFilling']:
                # Sub-case for CamRest
                if domain == 'CamRest':
                    self.dstc2_acts = camrest_acts
            else:
                print('Warning! Domain has not been defined. Using '
                      'Slot-Filling Dialogue State')

            # NOTE(review): d_state is only built and initialized here;
            # unlike the no-domain branch, NStateFeatures is not set for an
            # explicit domain. Confirm whether it should be derived from it.
            d_state = SlotFillingDialogueState(
                {'slots': self.ontology.ontology['system_requestable']})
            d_state.initialize()

        if self.dstc2_acts:
            self.NActions = len(self.dstc2_acts) + len(self.requestable_slots)

            if self.agent_role == 'system':
                self.NActions += len(self.system_requestable_slots)

            elif self.agent_role == 'user':
                self.NActions += len(self.requestable_slots)
        else:
            if self.agent_role == 'system':
                self.NActions = 4 + len(self.system_requestable_slots) + \
                                len(self.requestable_slots)

            elif self.agent_role == 'user':
                self.NActions = 3 + 2 * len(self.requestable_slots)
Example #9
0
    def __init__(self,
                 ontology,
                 database,
                 agent_id=0,
                 agent_role='system',
                 domain=None):
        """
        Initialize parameters and internal structures of the supervised
        dialogue policy.

        Validates the ontology and database objects, caches the frequently
        used slot lists, selects the dialogue-act inventory for the given
        domain / agent role, and sizes the state (NStateFeatures) and action
        (NActions) spaces.

        :param ontology: the domain's ontology; must be an Ontology.Ontology
            instance
        :param database: the domain's database; must be a DataBase.DataBase
            instance
        :param agent_id: the agent's id (part of the tf_scope name)
        :param agent_role: the agent's role; 'system' and 'user' select
            different CamRest act sets and action-space sizes
        :param domain: the dialogue's domain. A falsy value uses hard-coded
            CamRest defaults (56 state features, default act list);
            'SlotFilling' and 'CamRest' build a SlotFillingDialogueState over
            the system-requestable slots; any other value prints a warning
            and builds one over the informable slots instead.
        :raises ValueError: if ontology or database has an unacceptable type
        """
        super(SupervisedPolicy, self).__init__()

        self.agent_id = agent_id
        self.agent_role = agent_role

        # True for greedy, False for stochastic
        self.IS_GREEDY_POLICY = False

        # Only an already-constructed Ontology object is accepted here (no
        # path strings); anything else aborts initialization.
        self.ontology = None
        if isinstance(ontology, Ontology.Ontology):
            self.ontology = ontology
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'ontology type %s ' % ontology)

        # Likewise, only an already-constructed DataBase object is accepted.
        self.database = None
        if isinstance(database, DataBase.DataBase):
            self.database = database
        else:
            raise ValueError('Supervised DialoguePolicy: Unacceptable '
                             'database type %s ' % database)

        self.policy_path = None

        # policy_net and sess start as None. The scope name is unique per
        # role and agent id, e.g. 'policy_system_0'.
        self.policy_net = None
        self.tf_scope = "policy_" + self.agent_role + '_' + str(self.agent_id)
        self.sess = None

        # The system and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        # Default value
        self.is_training = True

        # Extract lists of slots that are frequently used; deepcopy makes the
        # cached lists independent copies of the ontology data.
        self.informable_slots = \
            deepcopy(list(self.ontology.ontology['informable'].keys()))
        # 'this' and 'signature' are appended as extra requestable
        # pseudo-slots (NOTE(review): their meaning is defined elsewhere).
        self.requestable_slots = \
            deepcopy(self.ontology.ontology['requestable'] +
                     ['this', 'signature'])
        self.system_requestable_slots = \
            deepcopy(self.ontology.ontology['system_requestable'])

        # Stays None unless domain is falsy, or domain is 'CamRest' with
        # agent_role 'system' or 'user'.
        self.dstc2_acts = None

        if not domain:
            # Default to CamRest dimensions
            self.NStateFeatures = 56

            # Default to CamRest actions
            self.dstc2_acts = [
                'repeat', 'canthelp', 'affirm', 'negate', 'deny', 'ack',
                'thankyou', 'bye', 'reqmore', 'hello', 'welcomemsg',
                'expl-conf', 'select', 'offer', 'reqalts', 'confirm-domain',
                'confirm'
            ]
        else:
            # Try to identify number of state features
            if domain in ['SlotFilling', 'CamRest']:
                DState = \
                    SlotFillingDialogueState(
                        {'slots': self.system_requestable_slots})

                # Plato does not use action masks (rules to define which
                # actions are valid from each state) and so training can
                # be harder. This becomes easier if we have a smaller
                # action set.

                # Sub-case for CamRest
                if domain == 'CamRest':
                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_sys = [
                        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye',
                        'reqmore', 'welcomemsg', 'expl-conf', 'select',
                        'repeat', 'confirm-domain', 'confirm'
                    ]

                    # Does not include inform and request that are modelled
                    # together with their arguments
                    self.dstc2_acts_usr = [
                        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye',
                        'reqmore', 'hello', 'expl-conf', 'repeat', 'reqalts',
                        'restart', 'confirm'
                    ]

                    # Pick the act inventory matching this agent's role; any
                    # other role leaves self.dstc2_acts as None.
                    if self.agent_role == 'system':
                        self.dstc2_acts = self.dstc2_acts_sys

                    elif self.agent_role == 'user':
                        self.dstc2_acts = self.dstc2_acts_usr

            else:
                print('Warning! Domain has not been defined. Using '
                      'Slot-Filling Dialogue State')
                DState = \
                    SlotFillingDialogueState({'slots': self.informable_slots})

            # The state-feature count is measured by encoding a freshly
            # initialized dialogue state rather than hard-coded.
            DState.initialize()
            self.NStateFeatures = len(self.encode_state(DState))
            print('Supervised DialoguePolicy automatically determined number '
                  'of state features: {0}'.format(self.NStateFeatures))

        # NOTE(review): this keys on domain == 'CamRest', not on whether
        # self.dstc2_acts is set. With a falsy domain the default act list
        # above is populated, yet NActions still falls through to 5 --
        # confirm this is intended (and consistent with how actions are
        # encoded elsewhere).
        # NOTE(review): with domain == 'CamRest' and an agent_role other than
        # 'system'/'user', self.dstc2_acts is still None and len() below
        # raises TypeError.
        if domain == 'CamRest':
            # Number of dialogue acts plus number of requestable slots, plus
            # the system-requestable slots (system role) or the requestable
            # slots again (any other role). Presumably the slot terms count
            # the inform/request actions modelled together with their
            # arguments -- confirm against the action encoder.
            self.NActions = len(self.dstc2_acts) + len(self.requestable_slots)

            if self.agent_role == 'system':
                self.NActions += len(self.system_requestable_slots)
            else:
                self.NActions += len(self.requestable_slots)
        else:
            # Fixed action-space size for every non-CamRest domain
            # (including a falsy domain).
            self.NActions = 5

        # NOTE(review): presumably a learning-rate-like hyperparameter --
        # confirm where it is consumed.
        self.policy_alpha = 0.05

        self.tf_saver = None