class DIALOG(nn.Module):
    """Task-oriented dialog model: a fixed BERT NLU front-end, an RNN context
    encoder, a copy/pointer-style belief-state decoder, a rule-based policy and
    template NLG (policy/NLG are only exercised by the commented-out code below).

    NOTE(review): `train_forward` reads and mutates `self.states`, which is
    never initialized anywhere in this class as shown — presumably set up by a
    caller or in code outside this chunk; confirm before use.
    """

    def __init__(self, vocab, db, config):
        """Build sub-modules from `config` and load pretrained word embeddings.

        vocab: project vocabulary object exposing `vocab_size`, `word2idx`,
               `idx2word`, `encode`, `decode`.
        db:    database handle; stored but unused in the code visible here.
        config: hyper-parameter namespace (hidden_size, dropout, num_layers, ...).
        """
        super(DIALOG, self).__init__()
        self.vocab = vocab
        self.db = db
        self.pointer_size = config.pointer_size
        self.max_belief_len = config.max_belief_len
        self.nlu = BERTNLU()  # fixed, not finetuning
        self.context_encoder = ContextEncoder(vocab.vocab_size, config.hidden_size, config.hidden_size, config.dropout, config.num_layers, vocab.word2idx["<pad>"])
        self.belief_decoder = BeliefDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.dropout, config.num_layers, config.max_value_len)
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        # self.action_decoder = ActionDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_act_len)
        # self.response_decoder = ResponseDecoder(vocab, self.context_encoder.embedding, config.hidden_size, config.pointer_size, config.dropout, config.num_layers, config.max_sentence_len)
        self.load_embedding()  # load Glove & Kazuma embedding

    def forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing, test=False):
        """
        turn_inputs: {
            "user": [batch, len]
            "response": [batch, len]
            "belief": [batch, slots, len]
            "gate": [batch, slots]
            "action": [batch, len]
            "usr": [batch] => string list
            "context: [batch, turns, 2] => string list
            "prev_gate": [batch, slots]
        }
        turn_contexts: [batch, len]
        turn_context_lengths: [batch]
        action_history: [batch, actions, 4] => string list
        """

        if test:
            # NOTE(review): `test_forward` is called with no arguments but is
            # not defined in this chunk — confirm it exists elsewhere, else
            # `test=True` raises AttributeError.
            return self.test_forward()
        else:
            return self.train_forward(turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing)

    def train_forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing):
        """One training step for a single dialog turn.

        Runs BERT NLU on each user utterance (updating `self.states` and
        `action_history` in place), encodes both the turn context and the
        accumulated action history, blends them with a learned sigmoid score,
        then decodes gates/values for every informable slot.

        Returns: (gate_loss, value_loss, acc_belief, belief_gen, action_history)
            gate_loss / value_loss: scalar losses.
            acc_belief: [batch, slots] 1/0 per-slot correctness matrix.
            belief_gen: [batch][slots] generated value token-id lists.
            action_history: the same (mutated) list passed in.
        """
        batch_size = turn_contexts.size(0)

        user = turn_inputs["usr"]
        history = turn_inputs["context"]
        for batch_idx in range(batch_size):
            # Record the raw user utterance in the running dialog history.
            self.states[batch_idx]["history"].append(["usr", turn_inputs["usr"][batch_idx]])

            input_action = self.nlu.predict(user[batch_idx], history[batch_idx])  # [[intent, domain, slot, value], ..., [intent, domain, slot, value]]
            self.states[batch_idx]["user_action"] = input_action
            action_history[batch_idx] += input_action  # in-place growth: caller's list is extended

            # update request state
            for action in input_action:
                intent, domain, slot, _ = action
                # NOTE(review): `_` (the raw value) is passed in and its cleaned
                # form is discarded — only the cleaned slot name is kept here.
                slot, _ = clean_slot_values(domain, slot, _)
                domain = domain.lower()
                if intent == "Request":
                    if not self.states[batch_idx]["request_state"].get(domain):
                        self.states[batch_idx]["request_state"][domain] = {}
                    self.states[batch_idx]["request_state"][domain][slot] = 0

        # calculate score of input action generated by BERT NLU
        # Flatten each batch element's action tuples into one <bos> ... <eos>
        # token-id sequence for the context encoder.
        encoded_action_history = []
        history_lengths = []
        for batch_idx, actions in enumerate(action_history):
            encoded_actions = []
            for action in actions:
                intent, domain, slot, value = action
                slot = ontology.normlize_slot_names.get(slot, slot)
                # [1:-1] strips the <bos>/<eos> added by vocab.encode — TODO confirm
                # that encode() really wraps with those two tokens.
                encoded_actions += self.vocab.encode(" ".join([intent, domain, slot, value]))[1:-1]
            encoded_actions = [self.vocab.word2idx["<bos>"]] + encoded_actions + [self.vocab.word2idx["<eos>"]]
            history_lengths.append(len(encoded_actions))
            encoded_action_history.append(encoded_actions)
        history_lengths = torch.tensor(history_lengths).cuda()
        # Right-pad all action-history sequences into one [batch, max_len] tensor.
        encoded_action_history_ = torch.zeros((batch_size, history_lengths.max().item()), dtype=torch.int64).cuda()
        for batch_idx, actions in enumerate(encoded_action_history):
            encoded_action_history_[batch_idx, :history_lengths[batch_idx]] = torch.tensor(actions)

        action_history_outputs, action_history_hidden = self.context_encoder(encoded_action_history_, history_lengths)  # [batch, len, hidden], [layers, batch, hidden]
        encoder_outputs_, encoder_hidden_ = self.context_encoder(turn_contexts, turn_context_lengths)  # [batch, len, hidden], [layers, batch, hidden]

        # Scalar gate in (0, 1) per batch element: dot product of the two
        # top-layer hidden states squashed by sigmoid.
        action_score = torch.sigmoid(torch.matmul(action_history_hidden[-1].unsqueeze(dim=1), encoder_hidden_[-1].unsqueeze(dim=2)))  # [batch, 1, 1]

        # action score for attention score
        # NOTE(review): "attenion" is a typo for "attention" (local name only,
        # kept as-is here; renaming would be a code change).
        action_score_attenion = torch.zeros(batch_size, (turn_context_lengths+history_lengths).max().item()).cuda()  # [batch, len]
        for batch_idx, context_len in enumerate(turn_context_lengths):
            # Context positions get weight (1 - score); action-history positions get score.
            action_score_attenion[batch_idx, :context_len] = 1-action_score.squeeze()[batch_idx]
            action_score_attenion[batch_idx, context_len:context_len+history_lengths[batch_idx].item()] = action_score.squeeze()[batch_idx]

        # weighted sum & weighted concat
        encoder_outputs, turn_contexts, turn_context_lengths = self.concat(encoder_outputs_, action_history_outputs, turn_contexts, turn_context_lengths, \
            encoded_action_history_, history_lengths)
        action_score = action_score.transpose(0,1).contiguous()
        # Convex blend of the two encoders' hidden states, gated by action_score.
        encoder_hidden = (1 - action_score) * encoder_hidden_ + action_score * action_history_hidden  # [layers, batch, hidden]

        gate_outputs, all_probs, all_pred_words = self.belief_decoder(encoder_outputs, encoder_hidden, turn_contexts, turn_context_lengths, \
            turn_inputs["belief"], teacher_forcing, action_score_attenion)  # [batch, slots, 3], [batch, slots, len, vocab], [batch, slots, len], [batch, slots]
        
        gate_preds = gate_outputs.argmax(dim=2)

        # prev_gate = turn_inputs.get("prev_gate")

        # Assemble the generated belief values: gates short-circuit to the
        # literal strings "none" / "don't care"; otherwise decoded tokens are
        # kept up to (and including) the first <eos>.
        max_value_len = 0
        belief_gen = []  # [batch, slots, len]
        for batch_idx, batch in enumerate(all_pred_words):
            belief_gen_ = []  # [slots, len]
            for slot_idx, pred_words in enumerate(batch):
                if gate_preds[batch_idx, slot_idx].item() == ontology.gate_dict["none"]:
                    belief_gen_.append(self.vocab.encode("none")[1:])
                    len_ = len(self.vocab.encode("none")[1:])
                elif gate_preds[batch_idx, slot_idx].item() == ontology.gate_dict["don't care"]:
                    belief_gen_.append(self.vocab.encode("don't care")[1:])
                    len_ = len(self.vocab.encode("don't care")[1:])
                else:
                    # NOTE(review): if no <eos> appears, `idx` ends at the last
                    # position and the whole sequence is kept; if `pred_words`
                    # were empty, `idx` would be unbound — presumably the
                    # decoder always emits at least one token; confirm.
                    for idx, value in enumerate(pred_words):
                        if value == self.vocab.word2idx["<eos>"]:
                            break
                    belief_gen_.append(pred_words[:idx+1].tolist())
                    len_ = idx + 1
                max_value_len = max(max_value_len, len_)
            belief_gen.append(belief_gen_)

        gate_label = turn_inputs["gate"]
        gate_loss = F.cross_entropy(gate_outputs.view(-1, 3), gate_label.view(-1))

        # if prev_gate is not None:
        #     turn_domain = self.make_turn_domain(prev_gate, gate_preds)  # [batch]
        # else:
        #     turn_domain = self.make_turn_domain(None, gate_preds)

        # prev_gate = gate_preds.detach()  # [batch, slots]

        # Per-slot accuracy matrix: 1 everywhere, zeroed wherever the gate or
        # the generated value disagrees with the label.
        acc_belief = torch.ones(batch_size, len(ontology.all_info_slots)).cuda()

        gate_mask = (gate_label != gate_preds)
        acc_belief.masked_fill_(gate_mask, value=0)  # fail to predict gate

        value_label = turn_inputs["belief"]
        value_label_lengths = torch.zeros(batch_size, len(ontology.all_info_slots), dtype=torch.int64).cuda()
        for batch_idx, batch in enumerate(value_label):
            for slot_idx, pred_words in enumerate(batch):
                value = pred_words[pred_words != self.vocab.word2idx["<pad>"]].tolist()  # remove padding
                value_label_lengths[batch_idx, slot_idx] = len(value)
                if value != belief_gen[batch_idx][slot_idx]:
                    acc_belief[batch_idx, slot_idx] = 0  # fail to predict value

        if teacher_forcing:
            value_loss = masked_cross_entropy_for_value(all_probs, value_label, value_label_lengths)
        else:
            # Without teacher forcing the generated length can differ from the
            # label length, so both are truncated to the shorter one.
            min_len = min(value_label.size(2), all_probs.size(2))
            value_loss = masked_cross_entropy_for_value(all_probs[:, :, :min_len, :].contiguous(), value_label[:, :, :min_len].contiguous(), value_label_lengths)

        ### make state for Rule policy
        # for batch_idx, belief in enumerate(belief_gen):
        #     for slot_idx, slot in enumerate(belief):
        #         value = self.vocab.decode(slot[:-1])
        #         domain, slot = ontology.all_info_slots[slot_idx].split("-")

        #         # book slots
        #         for slot_ in self.states[batch_idx]["belief_state"][domain]["book"].keys():
        #             if slot_ == "booked":
        #                 continue
        #             slot_ = ontology.normlize_slot_names.get(slot_, slot_)
        #             if slot_ == slot:
        #                 self.states[batch_idx]["belief_state"][domain]["book"][slot_] = value

        #         # semi slots
        #         for slot_ in self.states[batch_idx]["belief_state"][domain]["semi"].keys():
        #             slot_ = ontology.normlize_slot_names.get(slot_, slot_)
        #             if slot_ == slot:
        #                 self.states[batch_idx]["belief_state"][domain]["semi"][slot_] = value

        # ### policy
        # output_actions = []
        # for batch_idx, state in enumerate(self.states):
        #     output_actions.append(self.policy.predict(state))

        # ### NLG
        # model_responses = []
        # for batch_idx, output_action in enumerate(output_actions):
        #     model_responses.append(self.nlg.generate(output_actions[batch_idx]))

        return gate_loss, value_loss, acc_belief, belief_gen, action_history

    def load_embedding(self):
        """Overwrite the encoder embedding rows in place with concatenated
        GloVe + Kazuma character embeddings (zeros for OOV words)."""
        glove = GloveEmbedding()
        kazuma = KazumaCharEmbedding()
        # `embed` aliases the parameter tensor, so writes below mutate the
        # encoder's (and, via sharing, the belief decoder's) embedding directly.
        embed = self.context_encoder.embedding.weight.data
        for word, idx in self.vocab.word2idx.items():
            embed[idx] = torch.tensor(glove.emb(word, default="zero") + kazuma.emb(word, default="zero"))
        # self.context_encoder.embedding.weight.data = embed
        # self.belief_decoder.slot_embedding.weight.data = embed

    def make_turn_domain(self, prev_gate, gate):
        """Pick, per batch element, the domain with the most changed gates.

        prev_gate: [batch, slots] previous-turn gate predictions, or None on
                   the first turn (then any non-"none" gate counts as changed).
        gate:      [batch, slots] current gate predictions.
        Returns a Python list of domain indices, one per batch element.
        """
        batch_size = gate.size(0)
        turn_domain = torch.zeros((batch_size, len(ontology.all_domains))).cuda()
        if prev_gate is None:  # first turn
            turn_gate = (gate != ontology.gate_dict["none"]).long()
        else:
            turn_gate = (gate != prev_gate).long()  # find changed gate
        # Accumulate changed-slot counts into their owning domain's column.
        for slot_idx in range(len(ontology.all_info_slots)):
            domain, slot = ontology.all_info_slots[slot_idx].split("-")
            domain_idx = ontology.all_domains.index(domain)
            turn_domain[:, domain_idx] += turn_gate[:, slot_idx]
        turn_domain = turn_domain.argmax(dim=1).tolist()  # [batch]

        return turn_domain
        
    def parse_action(self, action):
        """
        action: [len] => list

        Decode token ids and fold them into "domain-action[-slot]" strings.
        NOTE(review): if a slot (or non-general action) token appears before
        any domain token, `domain`/`action` are referenced before assignment
        (NameError). Also the loop rebinds the `action` parameter. Presumably
        the decoder's output ordering prevents this — confirm.
        """

        domains = ontology.domain_action_slot["domain"]
        actions = ontology.domain_action_slot["action"]
        slots = ontology.domain_action_slot["slot"]

        parsed_action = []
        decoded_action = []

        for token in action:
            decoded_action.append(self.vocab.idx2word[token])
        
        for token in decoded_action:
            if token in domains:
                domain = token
            if token in actions:
                action = token
                if domain == "general":
                    # general acts have no slot, so emit immediately
                    parsed_action.append("{}-{}".format(domain, action))
            if token in slots:
                slot = token
                parsed_action.append("{}-{}-{}".format(domain, action, slot))

        return parsed_action

    def concat(self, encoder_outputs, action_history_outputs, contexts, context_lengths, action_history, history_lengths):
        """Concatenate, per batch element, the (unpadded) context with its
        action history — both the token ids and the encoder outputs — and
        re-pad to the new combined max length.

        Returns (new_encoder_outputs [batch, len, hidden],
                 new_contexts [batch, len], lengths [batch]).
        """
        batch_size = contexts.size(0)
        hidden_size = encoder_outputs.size(2)
        lengths = context_lengths + history_lengths
        new_contexts = torch.zeros(size=(batch_size, lengths.max().item()), dtype=torch.int64).cuda()
        new_encoder_outputs = torch.zeros(size=(batch_size, lengths.max().item(), hidden_size)).cuda()

        for batch_idx in range(batch_size):
            new_contexts[batch_idx, :lengths[batch_idx]] = torch.cat([contexts[batch_idx, :context_lengths[batch_idx]], \
                action_history[batch_idx, :history_lengths[batch_idx]]], dim=0)  # [batch, len]
            new_encoder_outputs[batch_idx, :lengths[batch_idx], :] = torch.cat([encoder_outputs[batch_idx, :context_lengths[batch_idx], :], \
                action_history_outputs[batch_idx, :history_lengths[batch_idx], :]], dim=0)  # [batch, len, hidden]

        return new_encoder_outputs, new_contexts, lengths
# --- Example 2 (scraped-snippet separator; the class below is an independent snippet) ---
class Dialog(Agent):
    """Inference-time agent wrapping the trained DST model: BERT NLU for user
    acts, neural belief tracking, rule policy and template NLG for the reply."""

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Set up tokenizer, NLU, trained DST checkpoint, policy and NLG.

        NOTE(review): `model_file` is accepted but never used below; the
        checkpoint path is hard-coded. The two directory branches only create
        the directory and print — no download actually happens here.
        """
        super(Dialog, self).__init__(name=name)
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY,'multiwoz/data')):
            os.mkdir(os.path.join(DEFAULT_DIRECTORY,'multiwoz/data'))
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY,'multiwoz/save')):
            os.mkdir(os.path.join(DEFAULT_DIRECTORY,'multiwoz/save'))
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)
        model_path = ""  # NOTE(review): unused

        config = Config()
        parser = config.parser
        config = parser.parse_args()
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this scope — the
        # map_location lambda raises NameError if it is ever invoked; confirm
        # it is injected globally elsewhere.
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt", map_location = lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # Maps DST slot names back to MultiWOZ belief-state key spellings.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }

    def init_session(self):
        """Reset all components and per-dialog state for a fresh session."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()
        pass

    def response(self, user):
        """Produce the system reply for one user utterance.

        Runs NLU for Request acts, re-tokenizes the full history for the DST,
        merges generated slot values into `self.state`, then lets the rule
        policy + template NLG produce the reply. Mutates `self.history` and
        `self.state`.
        """
        self.history.append(["user", user])

        user_action = []

        # NLU sees all previous utterances (both sides) as context.
        self.input_action = self.nlu.predict(user, context=[x[1] for x in self.history[:-1]])
        self.input_action = deepcopy(self.input_action)
        for act in self.input_action:
            intent, domain, slot, value = act
            if intent == "Request":
                user_action.append(act)
                if not self.state["request_state"].get(domain):
                    self.state["request_state"][domain] = {}
                if slot not in self.state["request_state"][domain]:
                    self.state['request_state'][domain][slot] = 0
        
        context = " ".join([utterance[1] for utterance in self.history])
        # NOTE(review): this truncation is character-level (before tokenizing),
        # keeping the most recent MAX_CONTEXT_LENGTH characters.
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]
        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            value = belief_gen[slot_idx][:-1]  # remove <EOS>
            value = self.tokenizer.decode(value)
            if value != "none":
                # A newly filled ("" -> value) slot also becomes an Inform user act.
                if slot in self.state["belief_state"][domain]["book"].keys():
                    if self.state["belief_state"][domain]["book"][slot] == "":
                        action = ["Inform", domain.capitalize(), REF_USR_DA[domain].get(slot, slot), value]
                        user_action.append(action)
                    self.state["belief_state"][domain]["book"][slot] = value
                elif slot in self.state["belief_state"][domain]["semi"].keys():
                    if self.state["belief_state"][domain]["semi"][slot] == "":
                        action = ["Inform", domain.capitalize(), REF_USR_DA[domain].get(slot, slot), value]
                        user_action.append(action)
                    self.state["belief_state"][domain]["semi"][slot] = value

        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])

        return model_response
    "police":[],
    "hospital":["postcode"]
}

# Evaluation loop (truncated in this chunk). NOTE(review): `t`, `reader`,
# `nlu`, `dst`, `mapping` and `vocab` are defined outside the visible code.
with torch.no_grad():  # inference only — no gradients needed
    for batch_idx, batch in t:
        inputs, contexts, context_lengths, dial_ids = reader.make_input(batch)
        batch_size = len(contexts[0])
        turns = len(inputs)

        # Reset per-dialog state before replaying the turns.
        nlu.init_session()
        dst.init_session()

        for turn_idx in range(turns):
            context_len = contexts[turn_idx].size(1)
            # Index [0] — this loop evaluates only the first dialog of the batch.
            input_action = nlu.predict(inputs[turn_idx]["usr"][0], inputs[turn_idx]["context"][0])
            dst.state['user_action'] = input_action
            state = dst.update(input_action)
            belief = state["belief_state"]
            belief_label = inputs[turn_idx]["belief"][0]

            # joint_acc_ = 1
            # for slot_idx, value in enumerate(belief_label):
            #     slot = ontology.all_info_slots[slot_idx]
            #     domain, slot = slot.split("-")
            #     slot = mapping.get(slot, slot)
            #     value = value[value != 0].tolist()
            #     value = vocab.decode(value[:-1])
            #     if value == "none":
            #         value = ""
            #     elif value == "don't care":