Пример #1
0
def get_delex_dict(data_path):
    """Build the value -> placeholder tables used to delexicalize responses.

    For each domain in ``ontology.all_domains`` the values found in
    ``<data_path>/<domain>_db.json`` are mapped to placeholder tokens such as
    ``" [hotel_name] "`` or ``" [area] "``. The taxi domain is not read from a
    database file; a fixed list of car brands is mapped to ``" [taxi_car] "``
    instead. Booking reference numbers returned by ``get_references(data_path)``
    are added last as ``" [<domain>_reference] "``.

    Most keys and placeholders are padded with single spaces, presumably so a
    plain ``str.replace`` only hits whole space-separated values. Entries are
    inserted in database order; when the caller applies them as sequential
    replacements, that order can affect the result.

    Args:
        data_path: Directory containing the ``<domain>_db.json`` files.

    Returns:
        ``(delex_dict, train_time_dict)``. Train arrival/departure times are
        collected in ``train_time_dict`` (all mapped to ``" [train_time] "``)
        instead of ``delex_dict`` so they can be resolved to arrive/leave
        placeholders later, when the data is created.
    """
    delex_dict = {}
    train_time_dict = {}
    for domain in ontology.all_domains:
        if domain == "taxi":
            for car in ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus", "volvo", "volkswagen", "tesla"]:
                delex_dict[" {} ".format(car)] = " [taxi_car] "
            continue

        # Context manager so the database file is closed as soon as it is parsed.
        with open(os.path.join(data_path, "{}_db.json".format(domain)), "r") as f:
            db = json.load(f)

        for entry in db:
            for slot, value in entry.items():
                # Skip "location" in every domain and "price" only for hotels
                # (``and`` binds tighter than ``or``; the parentheses make it explicit).
                if slot == "location" or (slot == "price" and domain == "hotel"):
                    continue
                slot, value = clean_data.clean_slot_values(domain, slot, value)
                if slot == "price":
                    if value in ["?", "free"]:
                        continue
                    delex_dict[" {} ".format(value)] = " [price] "
                elif slot in ["address", "food", "postcode", "duration", "name"]:
                    delex_dict[" {} ".format(value)] = " [{}_{}] ".format(domain, slot)
                elif slot == "stars":
                    for rule in ["{} star", "{}-star", "rating of {}", "star of {}"]:
                        delex_dict[rule.format(value)] = rule.format("[hotel_stars]")
                elif slot == "id" and domain == "train":
                    delex_dict[" {} ".format(value)] = " [train_id] "
                elif slot in ["arrive", "leave"]:
                    # replaced by arrive or leave after delexicalization, while creating data
                    train_time_dict[" {} ".format(value)] = " [train_time] "
                elif slot == "pricerange":
                    if value in ["?", "free"]:
                        continue
                    delex_dict[" {} ".format(value)] = " [pricerange] "
                elif slot == "type" and domain in ["hotel", "attraction"]:
                    delex_dict[" {} ".format(value)] = " [{}_type] ".format(domain)
                elif slot == "area":
                    delex_dict[" {} ".format(value)] = " [area] "
                elif slot == "phone":
                    delex_dict[" {} ".format(value)] = " [phone] "
                elif slot == "department":
                    # NOTE(review): department values map to the [hospital_phone]
                    # placeholder; [hospital_department] looks intended. Left as-is
                    # because the placeholder vocabulary may depend on it — confirm.
                    delex_dict[" {} ".format(value)] = " [hospital_phone] "

    for value in get_references(data_path):
        domain, num = value.split("-")
        delex_dict[" {} ".format(num)] = " [{}_reference] ".format(domain)

    return delex_dict, train_time_dict
Пример #2
0
    def parse_goal(self, dial_id):
        """Extract the evaluation goal of one dialogue.

        The dialogue is looked up in the dev, test and train splits, in that
        order; the first split containing ``dial_id`` wins.

        Args:
            dial_id: Dialogue identifier as listed in ``self.dev_list``,
                ``self.test_list`` or ``self.train_list``.

        Returns:
            ``(inform_goal, request_goal)``:

            * ``inform_goal`` maps ``domain -> {slot: value}`` built from each
              domain's ``"info"`` constraints after ``clean_slot_values``.
            * ``request_goal`` maps ``domain -> [slot, ...]``: the requested
              slots that appear in ``self.requestables`` plus ``"reference"``
              for every domain that has a ``"book"`` goal.

            Domains with nothing to evaluate are omitted from both dicts.

        Raises:
            KeyError: if ``dial_id`` is in none of the three splits.
        """
        if dial_id in self.dev_list:
            goal = self.dev[dial_id]["goal"]
        elif dial_id in self.test_list:
            goal = self.test[dial_id]["goal"]
        elif dial_id in self.train_list:
            goal = self.train[dial_id]["goal"]
        else:
            # Without this branch ``goal`` would be unbound below.
            raise KeyError("dialogue {} is not in the dev, test or train split".format(dial_id))

        inform_goal = {}
        request_goal = {}

        for domain, goals in goal.items():
            for category, slots in goals.items():
                if category == "info":
                    constraints = {}
                    for slot, value in slots.items():
                        slot, value = clean_slot_values(domain, slot, value)
                        constraints[slot] = value
                    if constraints:
                        inform_goal[domain] = constraints
                elif category == "reqt":
                    requested = []
                    for slot in slots:
                        slot = slot.lower()
                        slot = ontology.normlize_slot_names.get(slot) or slot
                        if slot in self.requestables:  # evaluate only the tracked request slots
                            requested.append(slot)
                    # Merge instead of overwriting, so a "reference" added by a
                    # "book" entry that precedes "reqt" is not discarded.
                    if requested:
                        request_goal.setdefault(domain, []).extend(requested)
                elif category == "book":
                    # Booking constraints themselves are not evaluated; a booking
                    # goal only means the reference number must be provided.
                    request_goal.setdefault(domain, []).append("reference")

        return inform_goal, request_goal
Пример #3
0
    def create_data(self):
        """Convert the raw dialogues into train/dev/test dicts.

        For every dialogue in ``self.data`` (ids taken without their file
        extension) this:

        * keeps only goal domains listed in ``ontology.all_domains`` and
          normalizes requestable slot names in place; dialogues with no such
          domain (police/hospital only) and a few hard-coded ids are skipped;
        * turns every user/system utterance pair into one turn dict with the
          keys ``turn_num``, ``user``, ``response``, ``response_delex``,
          ``belief``, ``gate`` and ``action``.

        ``response_delex`` is the cleaned response with ``self.delex_dict``
        values replaced, train times from ``self.train_time_dict`` replaced
        when the response mentions a train cue word, 11-digit runs replaced by
        ``[phone]``, and each ``[train_time]`` resolved to ``[train_arrive]`` or
        ``[train_leave]``.

        It also writes ``ontology_processed.json`` (normalized slot names and
        cleaned values) into ``self.data_path``.

        Returns:
            ``(train, dev, test)`` dicts mapping dialogue id to the processed
            dialogue, following ``self.train_list``/``dev_list``/``test_list``.
        """
        train_cues = (
            "train", "arrive", "arrives", "arrived",
            "arriving", "arrival", "destination", "reach",
            "leave", "leaves", "leaving", "leaved", "depart",
            "departing", "departs", "departure", "[train_",
        )
        # Raw string: "\d" in a plain literal is an invalid escape sequence.
        phone_pattern = re.compile(r"(\d\s?){11}")
        prediction_gate = ontology.gate_dict["prediction"]
        # The gate key is spelled "don't care" elsewhere in this code base; the
        # value was previously compared only against "dontcare". Accept both.
        dontcare_values = ("don't care", "dontcare")
        dontcare_gate = next(
            (ontology.gate_dict[key] for key in dontcare_values if key in ontology.gate_dict),
            prediction_gate,
        )

        data = {}
        train = {}
        dev = {}
        test = {}
        ignore_list = ["SNG1213", "PMUL0382", "PMUL0237"]
        logger.info("Processing data...")
        for dial_id, dial in tqdm(self.data.items()):
            dial_id = dial_id.split(".")[0]
            if dial_id in ignore_list:
                continue

            goal = {}
            dial_domains = []
            for key, value in dial["goal"].items():  # process user's goal
                if key not in ontology.all_domains or value == {}:
                    continue
                if value.get("reqt"):  # normalize requestable slot names (in place)
                    for idx, slot in enumerate(value["reqt"]):
                        if ontology.normlize_slot_names.get(slot):
                            value["reqt"][idx] = ontology.normlize_slot_names[slot]
                goal[key] = value
                dial_domains.append(key)
            if not dial_domains:  # ignore police and hospital
                ignore_list.append(dial_id)
                continue

            dialogue = {"goal": goal, "log": []}
            acts = self.acts[dial_id]
            turn = {}
            for turn_num, turn_dial in enumerate(dial["log"]):
                meta_data = turn_dial["metadata"]
                if meta_data == {}:  # user turn
                    turn["turn_num"] = turn_num // 2
                    turn["user"] = clean_text(turn_dial["text"])
                    continue

                # system turn
                response = clean_text(turn_dial["text"])
                turn["response"] = response

                delexed = response
                for surface, placeholder in self.delex_dict.items():
                    delexed = delexed.replace(surface, placeholder)  # delexicalize values
                if any(cue in delexed for cue in train_cues):
                    for surface, placeholder in self.train_time_dict.items():
                        delexed = delexed.replace(surface, placeholder)  # delexicalize train times
                delexed = phone_pattern.sub("[phone]", delexed)  # delexicalize phone number
                turn["response_delex"] = self._label_train_times(delexed)

                belief = {}
                gate = {}
                for domain in dial_domains:  # active domains of dialogue
                    book_items = [(s, v) for s, v in meta_data[domain]["book"].items() if s != "booked"]
                    for slot, value in book_items + list(meta_data[domain]["semi"].items()):
                        slot, value = clean_slot_values(domain, slot, value)
                        if value == "":
                            continue
                        key = "{}-{}".format(domain, slot)
                        belief[key] = value
                        gate[key] = dontcare_gate if value in dontcare_values else prediction_gate
                turn["belief"] = belief
                turn["gate"] = gate

                act = {}
                turn_acts = acts.get(str(turn["turn_num"] + 1))
                if turn_acts and not isinstance(turn_acts, str):  # mapping system action
                    for domain_act, slots in turn_acts.items():
                        act_temp = []
                        for slot in slots:  # slot: [slot, value]
                            slot_, value_ = clean_slot_values(domain_act.split("-")[0], slot[0], slot[1])
                            if slot_ == "none" or value_ in ["?", "none"]:  # general domain or request slot or parking
                                act_temp.append(slot_)
                            else:
                                act_temp.append("{}-{}".format(slot_, value_))
                        act[domain_act.lower()] = act_temp
                turn["action"] = act

                dialogue["log"].append(turn)
                turn = {}  # clear turn

            data[dial_id] = dialogue

        logger.info("Processing finished.")
        logger.info("Dividing data to train/dev/test...")
        ignored = set(ignore_list)
        for split_ids, split in ((self.train_list, train), (self.dev_list, dev), (self.test_list, test)):
            for dial_id in split_ids:
                dial_id = dial_id.split(".")[0]
                if dial_id not in ignored:
                    split[dial_id] = data[dial_id]
        logger.info("Dividing finished.")

        with open(os.path.join(self.data_path, "ontology.json"), "r") as f:
            value_ontology = json.load(f)
        value_ontology_processed = {}

        logger.info("Processing ontology...")
        for domain_slot, values in value_ontology.items():
            parts = domain_slot.split("-")
            domain = parts[0]
            slot = parts[2].lower()
            slot = ontology.normlize_slot_names.get(slot) or slot
            domain_slot = "-".join([domain, slot])
            value_ontology_processed[domain_slot] = [clean_slot_values(domain, slot, value)[1] for value in values]
        # Write next to the input ontology (``data_path`` is not defined in this scope).
        with open(os.path.join(self.data_path, "ontology_processed.json"), "w") as f:
            json.dump(value_ontology_processed, f, indent=2)
        logger.info("Ontology was processed.")

        return train, dev, test

    @staticmethod
    def _label_train_times(response):
        """Resolve every ``[train_time]`` token in ``response``.

        Each token becomes ``[train_arrive]`` or ``[train_leave]`` according to
        the nearest preceding cue word; with no cue it defaults to
        ``[train_leave]``. A response without ``[train_time]`` is returned
        unchanged; otherwise its tokens are re-joined with single spaces.
        """
        if "[train_time]" not in response:
            return response
        arrive_cues = {
            "arrive", "arrives", "arrived", "arriving",
            "arrival", "destination", "reach", "by",
            "before", "have", "to",
        }
        leave_cues = {
            "leave", "leaves", "leaving", "leaved",
            "depart", "departing", "departs",
            "departure", "from", "after", "earlier",
            "there",
        }
        tokens = response.split()
        for idx, token in enumerate(tokens):
            if token != "[train_time]":
                continue
            label = "[train_leave]"  # default when no cue word precedes the time
            for previous in reversed(tokens[:idx]):
                if previous in arrive_cues:
                    label = "[train_arrive]"
                    break
                if previous in leave_cues:
                    label = "[train_leave]"
                    break
            tokens[idx] = label
        return " ".join(tokens)
    def train_forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing):
        """Run belief tracking on one batch of turns and compute training losses.

        Side effects: appends the user utterance to each
        ``self.states[i]["history"]``, stores the NLU prediction in
        ``self.states[i]["user_action"]``, marks requested slots in
        ``self.states[i]["request_state"]``, and extends ``action_history[i]``
        in place with the predicted user actions.

        The accumulated action history and the turn context are both encoded.
        A sigmoid score of their last-layer hidden states weights the two
        sources, both for the decoder's initial hidden state and for its
        attention.

        Args:
            turn_inputs: dict read for ``"usr"``, ``"context"``, ``"belief"``
                (padded value token ids per slot) and ``"gate"`` (gate labels).
            turn_contexts: context token ids; ``size(0)`` is the batch size.
            turn_context_lengths: per-example context lengths.
            action_history: per-example lists of ``[intent, domain, slot, value]``
                actions; mutated in place.
            teacher_forcing: forwarded to the belief decoder; also decides
                whether the value loss is truncated to a common length.

        Returns:
            ``(gate_loss, value_loss, acc_belief, belief_gen, action_history)``
            where ``acc_belief`` is a ``[batch, slots]`` 0/1 tensor (1 = gate and
            value both correct) and ``belief_gen`` holds the predicted value
            token ids per slot.
        """
        batch_size = turn_contexts.size(0)
        eos_id = self.vocab.word2idx["<eos>"]

        user = turn_inputs["usr"]
        history = turn_inputs["context"]
        for batch_idx in range(batch_size):
            state = self.states[batch_idx]
            state["history"].append(["usr", user[batch_idx]])

            # [[intent, domain, slot, value], ..., [intent, domain, slot, value]]
            input_action = self.nlu.predict(user[batch_idx], history[batch_idx])
            state["user_action"] = input_action
            action_history[batch_idx] += input_action

            # update request state
            request_state = state["request_state"]
            for intent, domain, slot, value in input_action:
                slot, _ = clean_slot_values(domain, slot, value)
                domain = domain.lower()
                if intent == "Request":
                    if not request_state.get(domain):
                        request_state[domain] = {}
                    request_state[domain][slot] = 0

        # calculate score of input action generated by BERT NLU
        encoded_action_history = []
        for actions in action_history:
            encoded_actions = []
            for intent, domain, slot, value in actions:
                slot = ontology.normlize_slot_names.get(slot, slot)
                encoded_actions += self.vocab.encode(" ".join([intent, domain, slot, value]))[1:-1]
            encoded_action_history.append([self.vocab.word2idx["<bos>"]] + encoded_actions + [eos_id])
        history_lengths = torch.tensor([len(actions) for actions in encoded_action_history]).cuda()
        encoded_action_history_ = torch.zeros((batch_size, history_lengths.max().item()), dtype=torch.int64).cuda()
        for batch_idx, actions in enumerate(encoded_action_history):
            encoded_action_history_[batch_idx, :len(actions)] = torch.tensor(actions)

        action_history_outputs, action_history_hidden = self.context_encoder(encoded_action_history_, history_lengths)  # [batch, len, hidden], [layers, batch, hidden]
        encoder_outputs_, encoder_hidden_ = self.context_encoder(turn_contexts, turn_context_lengths)  # [batch, len, hidden], [layers, batch, hidden]

        action_score = torch.sigmoid(torch.matmul(action_history_hidden[-1].unsqueeze(dim=1), encoder_hidden_[-1].unsqueeze(dim=2)))  # [batch, 1, 1]
        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor, which cannot be indexed by batch_idx.
        action_score_flat = action_score.reshape(-1)  # [batch]

        # action score for attention score: context positions get (1 - score),
        # the appended action-history positions get score
        action_score_attention = torch.zeros(batch_size, (turn_context_lengths + history_lengths).max().item()).cuda()  # [batch, len]
        for batch_idx, context_len in enumerate(turn_context_lengths):
            action_score_attention[batch_idx, :context_len] = 1 - action_score_flat[batch_idx]
            action_score_attention[batch_idx, context_len:context_len + history_lengths[batch_idx].item()] = action_score_flat[batch_idx]

        # weighted sum & weighted concat
        encoder_outputs, turn_contexts, turn_context_lengths = self.concat(
            encoder_outputs_, action_history_outputs, turn_contexts, turn_context_lengths,
            encoded_action_history_, history_lengths)
        action_score = action_score.transpose(0, 1).contiguous()
        encoder_hidden = (1 - action_score) * encoder_hidden_ + action_score * action_history_hidden  # [layers, batch, hidden]

        # [batch, slots, 3], [batch, slots, len, vocab], [batch, slots, len]
        gate_outputs, all_probs, all_pred_words = self.belief_decoder(
            encoder_outputs, encoder_hidden, turn_contexts, turn_context_lengths,
            turn_inputs["belief"], teacher_forcing, action_score_attention)

        gate_preds = gate_outputs.argmax(dim=2)

        # Token ids emitted for the fixed-value gates (leading <bos> dropped).
        none_ids = self.vocab.encode("none")[1:]
        dontcare_ids = self.vocab.encode("don't care")[1:]
        none_gate = ontology.gate_dict["none"]
        dontcare_gate = ontology.gate_dict["don't care"]

        belief_gen = []  # [batch, slots, len]
        for batch_idx, batch in enumerate(all_pred_words):
            belief_gen_ = []  # [slots, len]
            for slot_idx, pred_words in enumerate(batch):
                gate_pred = gate_preds[batch_idx, slot_idx].item()
                if gate_pred == none_gate:
                    belief_gen_.append(none_ids[:])
                elif gate_pred == dontcare_gate:
                    belief_gen_.append(dontcare_ids[:])
                else:
                    words = pred_words.tolist()
                    # keep everything up to and including the first <eos>
                    end = words.index(eos_id) + 1 if eos_id in words else len(words)
                    belief_gen_.append(words[:end])
            belief_gen.append(belief_gen_)

        gate_label = turn_inputs["gate"]
        gate_loss = F.cross_entropy(gate_outputs.view(-1, 3), gate_label.view(-1))

        acc_belief = torch.ones(batch_size, len(ontology.all_info_slots)).cuda()
        acc_belief.masked_fill_(gate_label != gate_preds, value=0)  # fail to predict gate

        pad_id = self.vocab.word2idx["<pad>"]
        value_label = turn_inputs["belief"]
        value_label_lengths = torch.zeros(batch_size, len(ontology.all_info_slots), dtype=torch.int64).cuda()
        for batch_idx, batch in enumerate(value_label):
            for slot_idx, label_words in enumerate(batch):
                value = label_words[label_words != pad_id].tolist()  # remove padding
                value_label_lengths[batch_idx, slot_idx] = len(value)
                if value != belief_gen[batch_idx][slot_idx]:
                    acc_belief[batch_idx, slot_idx] = 0  # fail to predict value

        if teacher_forcing:
            value_loss = masked_cross_entropy_for_value(all_probs, value_label, value_label_lengths)
        else:
            # Free-running decoding may produce a different length than the labels.
            min_len = min(value_label.size(2), all_probs.size(2))
            value_loss = masked_cross_entropy_for_value(
                all_probs[:, :, :min_len, :].contiguous(),
                value_label[:, :, :min_len].contiguous(),
                value_label_lengths)

        # Only belief tracking is trained here; no policy or NLG step is run.
        return gate_loss, value_loss, acc_belief, belief_gen, action_history
Пример #5
0
    def create_data(self):
        """Convert the raw dialogues into train/dev/test dicts.

        For every dialogue in ``self.data`` (ids taken without their file
        extension) this keeps only goal domains listed in
        ``ontology.all_domains``, normalizes requestable slot names in place,
        and turns every user/system utterance pair into one turn dict with the
        keys ``turn_num``, ``user``, ``response``, ``belief``, ``gate`` and
        ``action``. Dialogues with no such domain (police/hospital only) and a
        few hard-coded ids are skipped.

        It also writes ``ontology_processed.json`` (normalized slot names and
        cleaned values) into ``self.data_path``.

        Returns:
            ``(train, dev, test)`` dicts mapping dialogue id to the processed
            dialogue, following ``self.train_list``/``dev_list``/``test_list``.
        """
        data = {}
        train = {}
        dev = {}
        test = {}
        ignore_list = ["SNG1213", "PMUL0382", "PMUL0237"]
        prediction_gate = ontology.gate_dict["prediction"]
        logger.info("Processing data...")
        for dial_id, dial in tqdm(self.data.items()):
            dial_id = dial_id.split(".")[0]
            if dial_id in ignore_list:
                continue

            goal = {}
            dial_domains = []
            for key, value in dial["goal"].items():  # process user's goal
                if key not in ontology.all_domains or value == {}:
                    continue
                if value.get("reqt"):  # normalize requestable slot names (in place)
                    for idx, slot in enumerate(value["reqt"]):
                        if ontology.normlize_slot_names.get(slot):
                            value["reqt"][idx] = ontology.normlize_slot_names[slot]
                goal[key] = value
                dial_domains.append(key)
            if not dial_domains:  # ignore police and hospital
                ignore_list.append(dial_id)
                continue

            dialogue = {"goal": goal, "log": []}
            acts = self.acts[dial_id]
            turn = {}
            for turn_num, turn_dial in enumerate(dial["log"]):
                meta_data = turn_dial["metadata"]
                if meta_data == {}:  # user turn
                    turn["turn_num"] = turn_num // 2
                    turn["user"] = clean_text(turn_dial["text"])
                    continue

                # system turn
                turn["response"] = clean_text(turn_dial["text"])

                belief = {}
                gate = {}
                for domain in dial_domains:  # active domains of dialogue
                    book_items = [(s, v) for s, v in meta_data[domain]["book"].items() if s != "booked"]
                    for slot, value in book_items + list(meta_data[domain]["semi"].items()):
                        slot, value = clean_slot_values(domain, slot, value)
                        if value == "":
                            continue
                        key = "{}-{}".format(domain, slot)
                        belief[key] = value
                        gate[key] = ontology.gate_dict[value] if value == "don't care" else prediction_gate
                turn["belief"] = belief
                turn["gate"] = gate

                act = {}
                turn_acts = acts.get(str(turn["turn_num"] + 1))
                if turn_acts and not isinstance(turn_acts, str):  # mapping system action
                    for domain_act, slots in turn_acts.items():
                        act_temp = []
                        for slot in slots:  # slot: [slot, value]
                            slot_, value_ = clean_slot_values(domain_act.split("-")[0], slot[0], slot[1])
                            if slot_ == "none" or value_ in ["?", "none"]:  # general domain or request slot or parking
                                act_temp.append(slot_)
                            else:
                                act_temp.append("{}-{}".format(slot_, value_))
                        act[domain_act.lower()] = act_temp
                turn["action"] = act

                dialogue["log"].append(turn)
                turn = {}  # clear turn

            data[dial_id] = dialogue

        logger.info("Processing finished.")
        logger.info("Dividing data to train/dev/test...")
        ignored = set(ignore_list)
        for split_ids, split in ((self.train_list, train), (self.dev_list, dev), (self.test_list, test)):
            for dial_id in split_ids:
                dial_id = dial_id.split(".")[0]
                if dial_id not in ignored:
                    split[dial_id] = data[dial_id]
        logger.info("Dividing finished.")

        with open(os.path.join(self.data_path, "ontology.json"), "r") as f:
            value_ontology = json.load(f)
        value_ontology_processed = {}

        logger.info("Processing ontology...")
        for domain_slot, values in value_ontology.items():
            parts = domain_slot.split("-")
            domain = parts[0]
            slot = parts[2].lower()
            slot = ontology.normlize_slot_names.get(slot) or slot
            domain_slot = "-".join([domain, slot])
            value_ontology_processed[domain_slot] = [clean_slot_values(domain, slot, value)[1] for value in values]
        # Write next to the input ontology (``data_path`` is not defined in this scope).
        with open(os.path.join(self.data_path, "ontology_processed.json"), "w") as f:
            json.dump(value_ontology_processed, f, indent=2)
        logger.info("Ontology was processed.")

        return train, dev, test