def dialog_train(self, dialog):
        """Train the network on one dialog and return the mean per-turn loss.

        Args:
            dialog: Iterable of ``(utterance, response)`` pairs that also
                supports ``len()``. Each response is wrapped in a
                ``torch.LongTensor`` -- presumably an integer action id;
                TODO confirm against the data loader.

        Returns:
            Sum of the values returned by ``self.net.train_step`` divided by
            ``len(dialog)``.
        """
        # Every dialog starts with fresh trackers and a reset network state.
        entity_tracker = EntityTracker()
        action_tracker = ActionTracker(entity_tracker)
        self.net.reset_state()

        total_loss = 0.
        for utterance, target in dialog:
            # Return value is not needed here; context_features() is read
            # from the tracker right afterwards.
            entity_tracker.extract_entities(utterance)
            entity_features = entity_tracker.context_features()
            embedding = self.emb.encode(utterance)
            bag_of_words = self.bow_enc.encode(utterance)

            # Network input: entity flags, utterance embedding, bag of words.
            stacked = np.concatenate(
                (entity_features, embedding, bag_of_words), axis=0)
            features = torch.autograd.Variable(
                torch.from_numpy(stacked)).float()

            mask = torch.autograd.Variable(
                torch.from_numpy(action_tracker.action_mask()))
            target = torch.autograd.Variable(torch.LongTensor([target]))

            total_loss += self.net.train_step(features, target, mask)
        return total_loss/len(dialog)
예제 #2
0
    def dialog_train(self, dialog):
        """Train on one dialog and return the mean per-turn loss.

        The feature vector is assembled in the fixed order
        ``entity features, [embedding], bag of words, [action mask]``; the
        bracketed parts are included only when ``self.is_emb`` /
        ``self.is_action_mask`` are truthy.

        Args:
            dialog: Iterable of ``(utterance, response)`` pairs that also
                supports ``len()``.

        Returns:
            Accumulated ``self.net.train_step`` results divided by
            ``len(dialog)``.
        """
        # Language-specific implementations are chosen at call time.
        if self.lang_type == 'eng':
            from modules.entities import EntityTracker
            from modules.data_utils import Data
            from modules.actions import ActionTracker
            from modules.bow import BoW_encoder
        elif self.lang_type == 'kor':
            from modules.entities_kor import EntityTracker
            from modules.data_utils_kor import Data
            from modules.actions_kor import ActionTracker
            from modules.bow_kor import BoW_encoder

        tracker = EntityTracker()
        actions = ActionTracker(tracker)
        # The network state is reset at the start of every dialog.
        self.net.reset_state()

        total = 0.
        for utterance, target in dialog:
            tracker.extract_entities(utterance)
            pieces = [tracker.context_features()]
            bag_of_words = self.bow_enc.encode(utterance)
            if self.is_emb:
                pieces.append(self.emb.encode(utterance))
            pieces.append(bag_of_words)

            # With masking enabled the mask is both appended to the features
            # and handed to the training step as a separate argument.
            if self.is_action_mask:
                mask = actions.action_mask()
                features = np.concatenate(tuple(pieces + [mask]), axis=0)
                step_loss = self.net.train_step(features, target, mask)
            else:
                features = np.concatenate(tuple(pieces), axis=0)
                step_loss = self.net.train_step(features, target)
            total += step_loss

        return total / len(dialog)
    def evaluate(self):
        """Return the mean per-dialog accuracy over the development dialogs.

        For each entry of ``self.dialog_indices_dev`` the matching slice of
        ``self.dataset`` is replayed through the network with fresh trackers
        and a reset network state. The fraction of turns whose prediction
        equals the reference response is averaged over all dialogs. The share
        of turns whose reference response is 15 is printed as a side effect.

        Returns:
            Mean of the per-dialog accuracies, or ``0.`` when there are no
            development dialogs (previously this case raised
            ``UnboundLocalError`` because the dialog count was only bound
            inside the loop).
        """
        num_dev_examples = len(self.dialog_indices_dev)
        if num_dev_examples == 0:
            return 0.

        dialog_accuracy = 0.
        r_count = 0  # turns whose reference response is 15
        count = 0  # total number of turns seen
        for dialog_idx in self.dialog_indices_dev:
            start, end = dialog_idx['start'], dialog_idx['end']
            dialog = self.dataset[start:end]

            # Each dialog gets fresh trackers and a reset network state.
            et = EntityTracker()
            at = ActionTracker(et)
            self.net.reset_state()

            correct_examples = 0
            for (u, r) in dialog:
                # Return value unused; context_features() is read next.
                et.extract_entities(u)
                u_ent_features = et.context_features()
                u_emb = self.emb.encode(u)
                u_bow = self.bow_enc.encode(u)
                features = torch.autograd.Variable(torch.from_numpy(
                    np.concatenate((u_ent_features, u_emb, u_bow),
                                   axis=0))).float()
                action_mask = torch.autograd.Variable(
                    torch.from_numpy(at.action_mask()))
                # forward() returns (logits, probs, prediction); only the
                # prediction is scored.
                _, _, prediction = self.net.forward(features, action_mask)
                correct_examples += int(prediction == r)
                if r == 15:
                    r_count += 1
                count += 1
            dialog_accuracy += correct_examples/len(dialog)

        print("task 15 was the answer with freq", r_count/count)

        return dialog_accuracy/num_dev_examples
예제 #4
0
    def interact(self):
        """Run a console chat loop until the user types an exit command.

        Commands: ``clear``/``reset``/``restart`` start a new session,
        ``exit``/``stop``/``quit``/``q`` leave the loop, and an empty line is
        sent to the model as ``<SILENCE>``. Every other line is encoded, run
        through the network and answered with either an ``api_call`` line or
        an action template.
        """
        entity_tracker = EntityTracker()
        action_tracker = ActionTracker(entity_tracker)
        self.net.reset_state()

        while True:
            u = input(':: ')

            if u in ('exit', 'stop', 'quit', 'q'):
                break

            if u in ('clear', 'reset', 'restart'):
                # New session: reset the network and rebuild both trackers.
                self.net.reset_state()
                entity_tracker = EntityTracker()
                action_tracker = ActionTracker(entity_tracker)
                print('')
                continue

            # ENTER press : silence
            if not u:
                u = '<SILENCE>'

            _, u_entities = entity_tracker.extract_entities(u, is_test=True)
            u_ent_features = entity_tracker.context_features()
            utterance_embedding = self.emb.encode(u)
            bag_of_words = self.bow_enc.encode(u)
            features = np.concatenate(
                (u_ent_features, utterance_embedding, bag_of_words), axis=0)
            mask = action_tracker.action_mask()

            prediction = self.net.forward(features, mask)

            if not self.post_process(prediction, u_ent_features):
                prediction = self.action_post_process(prediction, u_entities)
                print('>>', self.action_templates[prediction])
                continue

            # post_process() said yes: emit the api_call line built from the
            # four tracked slots.
            api_call = ('api_call ' + u_entities['<cuisine>'] + ' ' +
                        u_entities['<location>'] + ' ' +
                        u_entities['<party_size>'] + ' ' +
                        u_entities['<rest_type>'])
            print('>>', api_call)
    def evaluate(self):
        """Score the network on the development dialogs.

        Each entry of ``self.dialog_indices_dev`` selects a slice of
        ``self.dataset`` that is replayed with fresh trackers and a reset
        network state.

        Returns:
            ``(per_response_accuracy, per_dialogue_accuracy)``: the mean
            fraction of correctly predicted turns per dialog, and the
            fraction of dialogs in which every turn was predicted correctly.
            ``(0., 0.)`` is returned when there are no development dialogs
            (previously this case raised ``UnboundLocalError`` because the
            dialog count was only bound inside the loop).
        """
        num_dev_examples = len(self.dialog_indices_dev)
        if num_dev_examples == 0:
            return 0., 0.

        dialog_accuracy = 0.
        correct_dialogue_count = 0
        for dialog_idx in self.dialog_indices_dev:
            start, end = dialog_idx['start'], dialog_idx['end']
            dialog = self.dataset[start:end]

            # Each dialog gets fresh trackers and a reset network state.
            et = EntityTracker()
            at = ActionTracker(et)
            self.net.reset_state()

            correct_examples = 0
            for (u, r) in dialog:
                # Return value unused; context_features() is read next.
                et.extract_entities(u)
                u_ent_features = et.context_features()
                u_emb = self.emb.encode(u)
                u_bow = self.bow_enc.encode(u)
                features = np.concatenate((u_ent_features, u_emb, u_bow),
                                          axis=0)
                action_mask = at.action_mask()
                prediction = self.net.forward(features, action_mask)
                correct_examples += int(prediction == r)

            if correct_examples == len(dialog):
                correct_dialogue_count += 1
            dialog_accuracy += correct_examples / len(dialog)

        per_response_accuracy = dialog_accuracy / num_dev_examples
        per_dialogue_accuracy = correct_dialogue_count / num_dev_examples
        return per_response_accuracy, per_dialogue_accuracy
예제 #6
0
    def interact(self, input):
        """Return the bot's reply to a single utterance.

        Args:
            input: The user's utterance. A falsy value is sent to the model
                as ``<SILENCE>``. (The name shadows the builtin but is kept
                because it is part of the method's signature.)

        Returns:
            The action template selected by the network, or ``''`` for the
            session commands ``clear``/``reset``/``restart``. Those commands
            previously fell through to the final ``return`` with
            ``prediction`` unassigned and raised ``UnboundLocalError``.
        """
        # NOTE(review): trackers and network state are rebuilt on every call,
        # so no context is carried from one call to the next -- confirm this
        # is intended for the single-shot use of this method.
        et = EntityTracker()
        at = ActionTracker(et)
        self.net.reset_state()

        u = input
        if u == 'clear' or u == 'reset' or u == 'restart':
            # Everything was just reset above; only the blank console line of
            # the original loop version is kept.
            print('')
            return ''

        # ENTER press : silence
        if not u:
            u = '<SILENCE>'

        # Return value unused; context_features() is read next.
        et.extract_entities(u)
        u_ent_features = et.context_features()
        u_emb = self.emb.encode(u)
        u_bow = self.bow_enc.encode(u)
        features = np.concatenate((u_ent_features, u_emb, u_bow), axis=0)
        action_mask = at.action_mask()

        prediction = self.net.forward(features, action_mask)
        return self.action_templates[prediction]
예제 #7
0
    def dialog_train(self, dialog):
        """Train the network on one dialog and return the mean per-turn loss.

        Args:
            dialog: Iterable of ``(utterance, response)`` pairs that also
                supports ``len()``.

        Returns:
            Accumulated ``self.net.train_step`` results divided by
            ``len(dialog)``.
        """
        # Every dialog starts with fresh trackers and a reset network state.
        entity_tracker = EntityTracker()
        action_tracker = ActionTracker(entity_tracker)
        self.net.reset_state()

        total_loss = 0.
        for utterance, target in dialog:
            # Return value is not needed; context_features() is read next.
            entity_tracker.extract_entities(utterance)
            feature_parts = (
                entity_tracker.context_features(),
                self.emb.encode(utterance),
                self.bow_enc.encode(utterance),
            )
            features = np.concatenate(feature_parts, axis=0)
            mask = action_tracker.action_mask()
            total_loss += self.net.train_step(features, target, mask)
        return total_loss/len(dialog)
예제 #8
0
    def interact(self):
        """Run a console chat loop until the user leaves.

        Commands (matched before lower-casing, so they are case-sensitive):
        ``clear``/``reset``/``restart`` start a new session;
        ``exit``/``stop``/``quit``/``q`` and the thank-you phrases end the
        loop; ``hello``/``hi`` gets a canned greeting. Empty input is
        ignored. Anything else is lower-cased, encoded and answered from the
        predicted action template.

        Prediction 0 fills the ``memory`` placeholder with the tracked
        ``slot=value`` pairs (all except ``<location>``) and then starts a
        new session; prediction 1 fills the ``location`` placeholder.
        """
        et = EntityTracker()
        at = ActionTracker(et)
        self.net.reset_state()

        while True:
            # The original line here was garbled ("input('User: '******...");
            # the prompt call and the reset check are restored as two
            # separate statements.
            u = input('User: ')

            # check if user wants to begin new session
            if u in ('clear', 'reset', 'restart'):
                self.net.reset_state()
                et = EntityTracker()
                at = ActionTracker(et)
                print('Bot: Reset successfully')

            # check for entrance and exit command
            elif u in ('exit', 'stop', 'quit', 'q'):
                print("Bot: Thank you for using")
                break

            elif u in ('hello', 'hi'):
                print("Bot: Hello, what can i do for you")

            elif u in ('thank you', 'thanks', 'thank you very much'):
                print('Bot: You are welcome')
                break

            else:
                if not u:
                    continue

                u = u.lower()

                # Return value unused; context_features() is read next.
                et.extract_entities(u)
                u_ent_features = et.context_features()
                u_emb = self.emb.encode(u)
                u_bow = self.bow_enc.encode(u)
                features = np.concatenate((u_ent_features, u_emb, u_bow),
                                          axis=0)
                action_mask = at.action_mask()

                prediction = self.net.forward(features, action_mask)
                response = self.action_templates[prediction]
                if prediction == 0:
                    slot_values = copy.deepcopy(et.entities)
                    slot_values.pop('<location>')
                    memory = []
                    for count, (k, v) in enumerate(slot_values.items(),
                                                   start=1):
                        memory.append('='.join([k, v]))
                        # A line break is inserted after the second pair.
                        if count == 2:
                            memory.append('\n')

                    response = response.replace("memory", ', '.join(memory))

                    # The request is complete: begin a new session.
                    self.net.reset_state()
                    et = EntityTracker()
                    at = ActionTracker(et)
                if prediction == 1:
                    response = response.replace(
                        "location", '<location>=' + et.entities['<location>'])
                print('Bot: ', response)
예제 #9
0
class Dialogue():
    """Dialogue manager that maps recognised utterances to system replies.

    Messages arriving on ``raising_events`` are run through the entity
    tracker, the feature encoders and the restored recurrent network. The
    chosen reply is published on ``reply``. Every turn is appended to an
    in-memory story, which is written to ``log/dialogue.txt`` of the
    ``dialogue_system`` package when template 6 is predicted.
    """

    # Slots shared by the person-related api_call formats.
    _PERSON_SLOTS = ('<first_name>', '<last_name>', '<address_number>',
                     '<address_name>', '<address_type>')

    # prediction -> (api_call name, slots appended in this order)
    _API_CALLS = {
        1: ('appointment', _PERSON_SLOTS + ('<time>', '<pm_am>')),
        2: ('location', ('<location>',)),
        3: ('prescription', _PERSON_SLOTS),
        4: ('waiting_time', _PERSON_SLOTS + ('<time>', '<pm_am>')),
    }

    # action id -> the entity slots that action refers to
    _ACTION_SLOTS = {
        11: ('<first_name>', '<last_name>'),
        12: ('<address_number>', '<address_name>', '<address_type>'),
        10: ('<time>', '<pm_am>'),
    }

    @staticmethod
    def _param_to_bool(value):
        """Interpret a parameter value as a boolean.

        Strings such as ``'false'`` are truthy in Python, so the string
        defaults used for the ROS parameters could never switch a feature
        off. Only explicit false-like strings map to ``False``; every other
        value keeps its ordinary truthiness.
        """
        if isinstance(value, (str, type(u''))):
            return value.strip().lower() not in ('', '0', 'false', 'no',
                                                 'off')
        return bool(value)

    def __init__(self):
        """Read parameters, build and restore the network, connect to ROS.

        Raises:
            ValueError: if ``~network_model`` names an unknown network type
                (previously this surfaced later as an ``AttributeError`` on
                ``self.net``).
        """
        # store whole dialogues
        self.story = []
        self.sp_confidecne = []
        self.file_path = os.path.join(
            rospkg.RosPack().get_path('dialogue_system'), 'log',
            'dialogue.txt')

        # count turn taking
        self.usr_count = 0
        self.sys_count = 0

        # parameters
        self.network_type = rospy.get_param('~network_model', 'stacked_lstm')
        self.lang_type = rospy.get_param('~lang', 'eng')
        self.is_emb = self._param_to_bool(
            rospy.get_param('~embedding', 'false'))
        self.is_am = self._param_to_bool(
            rospy.get_param('~action_mask', "true"))
        self.user_num = rospy.get_param('~user_number', '0')

        # call rest of modules
        self.et = EntityTracker()
        self.at = ActionTracker(self.et)
        self.bow_enc = BoW_encoder()
        self.emb = UtteranceEmbed(lang=self.lang_type)

        # Observation size: bag of words + entity flags, plus the optional
        # embedding and action-mask parts (same sums as the features built
        # in get_response).
        obs_size = self.bow_enc.vocab_size + self.et.num_features
        if self.is_emb:
            obs_size += self.emb.dim
        if self.is_am:
            obs_size += self.at.action_size

        self.action_template = self.at.get_action_templates()
        self.at.do_display_template()
        # must clear entities space
        self.et.do_clear_entities()
        action_size = self.at.action_size
        nb_hidden = 128

        # Only the selected class name is evaluated, as in the original
        # if/elif chain; every network takes the same constructor arguments.
        if self.network_type == 'gru':
            net_cls = GRU
        elif self.network_type == 'reversed_lstm':
            net_cls = ReversingLSTM
        elif self.network_type == 'reversed_gru':
            net_cls = ReversingGRU
        elif self.network_type == 'stacked_gru':
            net_cls = StackedGRU
        elif self.network_type == 'stacked_lstm':
            net_cls = StackedLSTM
        elif self.network_type == 'lstm':
            net_cls = LSTM
        elif self.network_type == 'bidirectional_lstm':
            net_cls = BidirectionalLSTM
        elif self.network_type == 'bidirectional_gru':
            net_cls = BidirectionalGRU
        else:
            raise ValueError('unknown network_model: %r' %
                             (self.network_type, ))
        self.net = net_cls(obs_size=obs_size,
                           nb_hidden=nb_hidden,
                           action_size=action_size,
                           lang=self.lang_type,
                           is_action_mask=self.is_am)

        # restore trained model
        self.net.restore()

        # rostopics
        self.pub_reply = rospy.Publisher('reply', Reply, queue_size=10)
        self.pub_complete = rospy.Publisher('complete_execute_scenario',
                                            Empty,
                                            queue_size=10)
        rospy.Subscriber('raising_events', RaisingEvents,
                         self.handle_raise_events)

        # Announce the wait before blocking on the service (it used to be
        # logged only after the wait had already finished).
        rospy.logwarn("waiting for reception DB module...")
        try:
            rospy.wait_for_service('reception_db/query_data')
            self.get_response_db = rospy.ServiceProxy(
                'reception_db/query_data', DBQuery)
        except rospy.exceptions.ROSInterruptException as e:
            rospy.logerr(e)
            quit()
        rospy.logwarn(
            "network: {}, lang: {}, action_mask: {}, embedding: {}, user_number: {}"
            .format(self.network_type, self.lang_type, self.is_am, self.is_emb,
                    self.user_num))
        self.story.append('user number: %s' % self.user_num)
        rospy.loginfo('\033[94m[%s]\033[0m initialized.' % rospy.get_name())

    def _format_api_call(self, prediction, u_entities):
        """Build the ``api_call <name> <slot values...>`` query string."""
        name, slots = self._API_CALLS[prediction]
        values = ' '.join('{}'.format(u_entities[slot]) for slot in slots)
        return 'api_call {} {}'.format(name, values)

    def get_response(self, utterance):
        """Predict an action for ``utterance`` and build the reply text.

        Args:
            utterance: The (already lower-cased) user utterance or
                ``<SILENCE>``.

        Returns:
            ``(prediction, probs, response)``. ``prediction`` and ``probs``
            are ``-1`` when the network could not be evaluated; previously
            that case raised ``UnboundLocalError`` at the return statement.
            ``response`` is a random entry of ``REPROMPT`` when the
            prediction confidence is too low or any step fails.
        """
        rospy.loginfo("actual input: %s" %
                      utterance)  # check actual user input

        # Same placeholders handle_raise_events uses for rejected input, so
        # the return value is always well defined.
        prediction = -1
        probs = -1

        _, u_entities = self.et.extract_entities(utterance, is_test=True)
        parts = [self.et.context_features()]
        u_bow = self.bow_enc.encode(utterance)
        if self.is_emb:
            parts.append(self.emb.encode(utterance))
        parts.append(u_bow)

        try:
            # predict template number
            if self.is_am:
                action_mask = self.at.action_mask()
                features = np.concatenate(parts + [action_mask], axis=0)
                probs, prediction = self.net.forward(features, action_mask)
            else:
                features = np.concatenate(parts, axis=0)
                probs, prediction = self.net.forward(features)

            # check response confidence
            if max(probs) > BOUNDARY_CONFIDENCE:
                # Kept as the fallback query text for the api branch below
                # when the prediction has no api_call format of its own.
                response = self.action_template[prediction]
                prediction = self.pre_action_process(prediction, u_entities)

                # handle api call
                # NOTE(review): post_process() names its second parameter
                # u_ent_features but receives the entity dict here, so its
                # "all features set" test iterates dict keys -- confirm
                # which argument is intended.
                if self.post_process(prediction, u_entities):
                    if prediction in self._API_CALLS:
                        response = self._format_api_call(
                            prediction, u_entities)
                    # query knowledge base; here we use dynamo db
                    response = self.get_response_db(response).response
                elif prediction in [6, 9, 11]:
                    template = self.action_template[prediction]
                    # Substitute only when the placeholder is present, so a
                    # template without it no longer fails (and falls back to
                    # a reprompt) just because the first name is unknown.
                    if '<first_name>' in template:
                        response = template.replace(
                            '<first_name>', u_entities['<first_name>'])
                    else:
                        response = template
                else:
                    response = self.action_template[prediction]
            else:
                # prediction confidence too low: reprompt
                response = random.choice(REPROMPT)
        except Exception as e:
            # Best effort: report the failure and ask the user again.
            rospy.logerr("failed to generate a response: %s" % e)
            response = random.choice(REPROMPT)

        return prediction, probs, response

    def handle_raise_events(self, msg):
        """Handle one recognition event: log it, answer it, publish the reply.

        The speech confidence is read from the JSON in ``msg.data[0]``; when
        it cannot be read the utterance is accepted without one. Input with
        a confidence at or below ``BOUNDARY_CONFIDENCE`` is answered with a
        reprompt. When template 6 is predicted the completion topic is
        published and the story is written to the log file.
        """
        utterance = msg.recognized_word
        try:
            # get confidence
            data = json.loads(msg.data[0])
            confidence = data['confidence']
        except Exception:
            confidence = None

        # Test for None first: ordering comparisons against None raise
        # TypeError on Python 3.
        if confidence is None or confidence > BOUNDARY_CONFIDENCE:
            if 'silency_detected' in msg.events:
                utterance = '<SILENCE>'
            else:
                try:
                    self.story.append(
                        "U%i: %s (sp_conf:%f)" %
                        (self.usr_count + 1, utterance, confidence))
                    self.sp_confidecne.append(confidence)
                except TypeError:
                    # No usable confidence value: log the turn without it.
                    self.story.append("U%i: %s" %
                                      (self.usr_count + 1, utterance))
                self.usr_count += 1
                utterance = utterance.lower()
            # generate system response
            prediction, probs, response = self.get_response(utterance)
        else:
            prediction = -1
            probs = -1
            response = random.choice(REPROMPT)

        # add system turn
        self.story.append("A%i: %s" % (self.sys_count + 1, response))
        self.sys_count += 1

        # finish interaction
        if (prediction == 6):
            self.pub_complete.publish()
            # logging user and system turn
            self.story.append("user: %i, system: %i" %
                              (self.usr_count, self.sys_count))
            # Without any recorded confidence the mean is undefined; the
            # line is skipped so the log is still written.
            if self.sp_confidecne:
                self.story.append(
                    "mean_sp_conf: %f" %
                    (sum(self.sp_confidecne) / len(self.sp_confidecne)))
            self.story.append(
                '==================================================================='
            )
            self.write_file(self.file_path, self.story)

        # display system response
        rospy.loginfo(json.dumps(self.et.entities,
                                 indent=2))  # recognized entity values
        try:
            rospy.logwarn("System: [conf: %f, predict: %d] / %s\n" %
                          (max(probs), prediction, response))
        except (TypeError, ValueError):
            # probs is the -1 placeholder (or empty): log the text only.
            rospy.logwarn("System: [] / %s\n" % (response))

        reply_msg = Reply()
        reply_msg.header.stamp = rospy.Time.now()
        reply_msg.reply = response

        self.pub_reply.publish(reply_msg)

    def post_process(self, prediction, u_ent_features):
        """Return True when ``prediction`` should be handled as an api call.

        That is the case for predictions 1-4, and for predictions 0 and
        9-12 when every element of ``u_ent_features`` equals 1.
        """
        if prediction in (1, 2, 3, 4):
            return True
        return (all(feature == 1 for feature in u_ent_features)
                and prediction in (0, 9, 10, 11, 12))

    def action_post_process(self, prediction, u_entities):
        """Redirect a slot-related action whose slots are already filled.

        The original mapping literal repeated the keys 10, 11 and 12, so
        Python kept only the last slot per action. The method could also
        raise ``IndexError`` when no entity was missing, and return ``None``
        when the first missing entity had no mapped action. The mapping is
        now one action to a group of slots.

        Args:
            prediction: Predicted action id.
            u_entities: Mapping of entity slot to value (``None`` = missing).

        Returns:
            ``prediction`` itself when it is not a slot-related action or
            one of its slots is still missing. Otherwise the action of the
            first missing entity that has one, or ``prediction`` when there
            is none.
        """
        slots = self._ACTION_SLOTS.get(prediction)
        if slots is None:
            return prediction
        if any(u_entities.get(slot) is None for slot in slots):
            return prediction

        # Everything this action refers to is known: point the user at the
        # first entity that is still missing instead.
        for entity, value in u_entities.items():
            if value is not None:
                continue
            for action, action_slots in self._ACTION_SLOTS.items():
                if entity in action_slots:
                    return action
        return prediction

    def pre_action_process(self, prediction, u_entities):
        """Replace api-call predictions 1, 3 and 4 by action 11 when the
        ``<first_name>`` entity is present in ``u_entities`` but ``None``;
        otherwise return ``prediction`` unchanged.
        """
        first_name_missing = ('<first_name>' in u_entities
                              and u_entities['<first_name>'] is None)
        if prediction in (1, 3, 4) and first_name_missing:
            return 11
        return prediction

    def write_file(self, path, story_list):
        """Append every item of ``story_list`` to ``path``, one per line."""
        with open(path, 'a') as f:
            for item in story_list:
                f.write("%s\n" % item)
        rospy.logwarn('save dialogue histories.')
class InteractiveSession:
    """Single-user chat session around a restored network.

    ``interact`` answers one utterance per call; the entity tracker, action
    tracker and network state persist between calls until ``reset`` is
    invoked (directly, by a command word, or after prediction 0).
    """

    def __init__(self):
        """Build trackers and encoders, then restore and reset the network."""
        self.et = EntityTracker()
        self.at = ActionTracker(self.et)

        self.bow_enc = BoW_encoder()
        self.emb = UtteranceEmbed()

        # Observation = utterance embedding + bag of words + entity flags.
        input_width = (self.emb.dim + self.bow_enc.vocab_size +
                       self.et.num_features)
        self.action_templates = self.at.get_action_templates()

        self.net = LSTM_net(obs_size=input_width,
                            action_size=self.at.action_size,
                            nb_hidden=128)

        # restore checkpoint
        self.net.restore()
        self.net.reset_state()

    def reset(self):
        """Start a new session: reset the network, rebuild both trackers."""
        self.net.reset_state()
        self.et = EntityTracker()
        self.at = ActionTracker(self.et)

    def interact(self, utterance):
        """Return the reply to ``utterance`` (matched case-insensitively).

        Command and courtesy words reset the session and return a canned
        reply. Any other text is encoded and answered from the predicted
        action template: prediction 0 fills the ``memory`` placeholder with
        the tracked values (all except ``<location>``) and resets the
        session; prediction 1 fills the ``location`` placeholder.
        """
        u = utterance.lower()

        canned_replies = {
            # begin a new session
            'clear': "reset successfully",
            'reset': "reset successfully",
            'restart': "reset successfully",
            # exit commands
            'exit': "Thank you for using",
            'stop': "Thank you for using",
            'quit': "Thank you for using",
            'q': "Thank you for using",
            # greetings
            'hello': "what can i do for you",
            'hi': "what can i do for you",
            # thanks
            'thank you': 'you are welcome',
            'thanks': 'you are welcome',
            'thank you very much': 'you are welcome',
        }
        if u in canned_replies:
            self.reset()
            return canned_replies[u]

        # Return value unused; context_features() is read next.
        self.et.extract_entities(u)
        entity_features = self.et.context_features()
        embedding = self.emb.encode(u)
        bag_of_words = self.bow_enc.encode(u)
        features = np.concatenate((entity_features, embedding, bag_of_words),
                                  axis=0)
        mask = self.at.action_mask()

        prediction = self.net.forward(features, mask)
        response = self.action_templates[prediction]
        if prediction == 0:
            # Work on a copy so the tracker's own dict is not modified.
            remembered = copy.deepcopy(self.et.entities)
            remembered.pop('<location>')
            response = response.replace("memory",
                                        ', '.join(remembered.values()))
            self.reset()
            print('API CALL execute successfully and begin new session')
        elif prediction == 1:
            response = response.replace("location",
                                        self.et.entities['<location>'])
        return response
예제 #11
0
    def evaluate(self, eval=False):
        """Score the network on the development dialogs.

        Args:
            eval: When true, ``self.net.restore()`` is called before
                scoring. (The name shadows the builtin but is kept because
                it is part of the method's signature.)

        Returns:
            ``(turn_accuracy, dialog_accuracy)``: the mean fraction of
            correctly predicted turns per dialog, and the fraction of
            dialogs in which every turn was predicted correctly. ``(0., 0.)``
            is returned when there are no development dialogs (previously
            this case raised ``UnboundLocalError`` because the dialog count
            was only bound inside the loop).
        """
        # Language-specific implementations are chosen at call time.
        if self.lang_type == 'eng':
            from modules.entities import EntityTracker
            from modules.data_utils import Data
            from modules.actions import ActionTracker
            from modules.bow import BoW_encoder

        elif self.lang_type == 'kor':
            from modules.entities_kor import EntityTracker
            from modules.data_utils_kor import Data
            from modules.actions_kor import ActionTracker
            from modules.bow_kor import BoW_encoder

        # only for evaluation purpose
        if eval:
            self.net.restore()

        num_dev_examples = len(self.dialog_indices_dev)
        if num_dev_examples == 0:
            return 0., 0.

        turn_accuracy = 0.
        dialog_accuracy = 0.
        for dialog_idx in self.dialog_indices_dev:
            start, end = dialog_idx['start'], dialog_idx['end']
            dialog = self.dataset[start:end]

            # Each dialog gets fresh trackers and a reset network state.
            et = EntityTracker()
            at = ActionTracker(et)
            self.net.reset_state()

            correct_examples = 0
            for (u, r) in dialog:
                # Return value unused; context_features() is read next.
                et.extract_entities(u)
                # Feature order: entity features, [embedding], bag of words,
                # [action mask].
                parts = [et.context_features()]
                u_bow = self.bow_enc.encode(u)
                if self.is_emb:
                    parts.append(self.emb.encode(u))
                parts.append(u_bow)

                if self.is_action_mask:
                    action_mask = at.action_mask()
                    features = np.concatenate(parts + [action_mask], axis=0)
                    probs, prediction = self.net.forward(features, action_mask)
                else:
                    features = np.concatenate(parts, axis=0)
                    probs, prediction = self.net.forward(features)

                correct_examples += int(prediction == r)

            turn_accuracy += correct_examples / len(dialog)
            # A dialog counts as correct only when every turn was right.
            if correct_examples == len(dialog):
                dialog_accuracy += 1

        turn_accuracy = turn_accuracy / num_dev_examples
        dialog_accuracy = dialog_accuracy / num_dev_examples

        return turn_accuracy, dialog_accuracy