# Example no. 1
    def update(self, user_act=None):
        """Update the dialog state from a single user utterance.

        Builds a fake one-turn dialogue from ``user_act``, runs the MDBT
        belief-tracking model on it, and merges the predicted
        domain-slot-value triples into copies of the previous belief state
        and request state.

        Args:
            user_act (str): raw user utterance for the current turn.

        Returns:
            dict: the updated dialog state (also stored in ``self.state``).

        Raises:
            Exception: if ``user_act`` is not a str, or if a predicted
                domain is not present in the belief state.
        """
        # Keep the original exception type and message so existing callers
        # that catch/parse it are unaffected; use isinstance for the check.
        if not isinstance(user_act, str):
            raise Exception(
                'Expected user_act to be <class \'str\'> type, but get {}.'.
                format(type(user_act)))
        prev_state = self.state
        results_dir = os.path.join(self.data_dir, "results")
        if not os.path.exists(results_dir):
            os.makedirs(results_dir)

        global train_batch_size

        # Unpack the model's placeholders and output tensors.
        model_variables = self.model_variables
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
         domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
         value_f1, train_step, keep_prob, predictions, true_predictions,
         [y, _]) = model_variables

        # Generate a fake single-turn dialogue so the original MDBT
        # preprocessing code can be reused. NOTE(review): the original code
        # deep-copied prev_state['history'] and then immediately discarded
        # it; only the current user_act is actually fed to the model.
        actual_history = [['null', user_act]]
        actual_history = self.normalize_history(actual_history)
        if len(actual_history) == 0:
            actual_history = [[
                '', user_act if len(user_act) > 0 else 'fake user act'
            ]]
        fake_dialogue = {}
        for turn_no, (_sys, _user) in enumerate(actual_history):
            fake_dialogue[str(turn_no)] = {
                'system': _sys,
                'user': {
                    'text': _user,
                    'belief_state': default_state()['belief_state'],
                },
            }
        context, actual_context = process_history([fake_dialogue],
                                                  self.word_vectors,
                                                  self.ontology)
        batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
                batch_no_turns = generate_batch(context, 0, 1, len(self.ontology))  # old feature

        # Run the model; keep_prob=1.0 disables dropout at inference time.
        [pred, y_pred] = self.sess.run(
            [predictions, y],
            feed_dict={
                user: batch_user,
                sys_res: batch_sys,
                labels: batch_labels,
                domain_labels: batch_domain_labels,
                user_uttr_len: batch_user_uttr_len,
                sys_uttr_len: batch_sys_uttr_len,
                no_turns: batch_no_turns,
                keep_prob: 1.0
            })

        # Convert the model output to "domain-slot-value" strings.
        dialgs, _, _ = track_dialogue(actual_context, self.ontology, pred,
                                      y_pred)
        assert len(dialgs) >= 1
        last_turn = dialgs[0][-1]
        # Renamed from `predictions` to avoid shadowing the tensor above.
        predicted_items = last_turn['prediction']
        new_belief_state = copy.deepcopy(prev_state['belief_state'])

        # Update the belief state with each predicted triple.
        for item in predicted_items:
            item = item.lower()
            domain, slot, value = item.strip().split('-')
            # Drop the suffix after the last ':' — equivalent to the
            # original reverse/split/reverse trick, but also tolerant of a
            # value with no ':' (the original raised IndexError then).
            value = value.rsplit(':', 1)[0]
            if slot == 'price range':
                slot = 'pricerange'
            if slot not in ['name', 'book']:
                if domain not in new_belief_state:
                    raise Exception(
                        'Error: domain <{}> not in belief state'.format(
                            domain))
                # Map the surface slot name to the belief-state slot name.
                slot = REF_SYS_DA[domain.capitalize()].get(slot, slot)
                assert 'semi' in new_belief_state[domain]
                assert 'book' in new_belief_state[domain]
                if 'book' in slot:
                    assert slot.startswith('book ')
                    slot = slot.strip().split()[1]
                domain_dic = new_belief_state[domain]
                if slot in domain_dic['semi']:
                    new_belief_state[domain]['semi'][slot] = normalize_value(
                        self.value_dict, domain, slot, value)
                elif slot in domain_dic['book']:
                    new_belief_state[domain]['book'][slot] = value
                elif slot.lower() in domain_dic['book']:
                    new_belief_state[domain]['book'][slot.lower()] = value
                else:
                    # Unknown slot: log for later inspection instead of
                    # crashing the tracker.
                    with open('mdbt_unknown_slot.log', 'a+') as f:
                        f.write(
                            'unknown slot name <{}> with value <{}> of domain <{}>\nitem: {}\n\n'
                            .format(slot, value, domain, item))
        new_request_state = copy.deepcopy(prev_state['request_state'])
        # Update the request state with slots the user asked about.
        user_request_slot = self.detect_requestable_slots(user_act)
        for domain in user_request_slot:
            for key in user_request_slot[domain]:
                if domain not in new_request_state:
                    new_request_state[domain] = {}
                if key not in new_request_state[domain]:
                    new_request_state[domain][key] = user_request_slot[domain][
                        key]
        # Assemble and store the new state.
        new_state = copy.deepcopy(dict(prev_state))
        new_state['belief_state'] = new_belief_state
        new_state['request_state'] = new_request_state
        self.state = new_state
        return self.state
# Example no. 2
    def test(self, sess):
        """Test the MDBT model on the mdbt dataset. Almost the same as original code.

        Iterates over the test dialogues in batches, runs the model, and
        accumulates domain/slot/value accuracies, precision, recall, F1 and
        joint accuracy, printing per-batch and overall figures. Processed
        dialogues are dumped as JSON to ``self.results_url``.

        Args:
            sess: an active TensorFlow session holding the trained model.
        """
        # NOTE(review): hard-coded relative path, unlike update() which
        # uses self.data_dir — confirm this is intentional.
        if not os.path.exists("../../data/mdbt/results"):
            os.makedirs("../../data/mdbt/results")

        global train_batch_size, MODEL_URL, GRAPH_URL

        # Unpack the model's placeholders and metric tensors.
        model_variables = self.model_variables
        (user, sys_res, no_turns, user_uttr_len, sys_uttr_len, labels,
         domain_labels, domain_accuracy, slot_accuracy, value_accuracy,
         value_f1, train_step, keep_prob, predictions, true_predictions,
         [y, _]) = model_variables
        [precision, recall, value_f1] = value_f1

        iterations = math.ceil(self.no_dialogues / train_batch_size)
        batch_size = train_batch_size
        # Running sums over batches; each is divided by `iterations` at
        # the end to report a mean.
        slot_acc = np.zeros(len(self.ontology), dtype="float32")
        tot_accuracy = 0
        slot_accuracy_sum = 0
        value_accuracy_sum = 0
        joint_accuracy = 0
        f1_score = 0
        precision_sum = 0
        recall_sum = 0
        processed_dialogues = []
        for batch_id in range(int(iterations)):

            # The final batch may be smaller than train_batch_size.
            if batch_id == iterations - 1:
                batch_size = self.no_dialogues - batch_id * train_batch_size

            batch_user, batch_sys, batch_labels, batch_domain_labels, batch_user_uttr_len, batch_sys_uttr_len, \
            batch_no_turns = generate_batch(self.dialogues, batch_id, batch_size, len(self.ontology))

            # keep_prob=1.0 disables dropout at evaluation time.
            [da, sa, va, vf, pr, re, pred, true_pred, y_pred] = sess.run(
                [
                    domain_accuracy, slot_accuracy, value_accuracy, value_f1,
                    precision, recall, predictions, true_predictions, y
                ],
                feed_dict={
                    user: batch_user,
                    sys_res: batch_sys,
                    labels: batch_labels,
                    domain_labels: batch_domain_labels,
                    user_uttr_len: batch_user_uttr_len,
                    sys_uttr_len: batch_sys_uttr_len,
                    no_turns: batch_no_turns,
                    keep_prob: 1.0
                })

            # NOTE(review): the original computed a per-batch joint
            # accuracy from pred/true_pred here, but the result was never
            # used (its accumulation was commented out and the variable is
            # reassigned from track_dialogue below) and the division could
            # raise ZeroDivisionError — removed as dead code.
            tot_accuracy += da
            slot_accuracy_sum += sa
            # Precision/F1 can come back NaN on batches with no positive
            # labels; treat those as 0 so the running sums stay finite.
            if math.isnan(pr):
                pr = 0
            precision_sum += pr
            recall_sum += re
            if math.isnan(vf):
                vf = 0
            f1_score += vf
            # Per-slot prediction accuracy, averaged over the batch axis.
            slot_acc += np.mean(np.asarray(np.equal(pred, true_pred),
                                           dtype="float32"),
                                axis=0)

            # Decode predictions into readable dialogues and get the
            # value/joint accuracies actually reported.
            dialgs, va1, ja = track_dialogue(
                self.actual_dialogues[batch_id * train_batch_size:batch_id *
                                      train_batch_size + batch_size],
                self.ontology, pred, y_pred)
            processed_dialogues += dialgs
            joint_accuracy += ja
            value_accuracy_sum += va1

            print(
                "The accuracies for domain is {:.2f}, slot {:.2f}, value {:.2f}, other value {:.2f}, f1_score {:.2f} precision {:.2f}"
                " recall {:.2f}  for batch {}".format(da, sa, np.mean(va), va1,
                                                      vf, pr, re, batch_id))

        print(
            "End of evaluating the test set..........................................................................."
        )

        slot_acc /= iterations
        print("The overall accuracies for domain is"
              " {}, slot {}, value {}, f1_score {}, precision {},"
              " recall {}, joint accuracy {}".format(
                  tot_accuracy / iterations, slot_accuracy_sum / iterations,
                  value_accuracy_sum / iterations, f1_score / iterations,
                  precision_sum / iterations, recall_sum / iterations,
                  joint_accuracy / iterations))

        with open(self.results_url, 'w') as f:
            json.dump(processed_dialogues, f, indent=4)