示例#1
0
class Dialog(Agent):
    """Pipeline dialog agent for MultiWOZ.

    Components (all built in ``__init__``):

    * ``BERTNLU`` parses the user utterance into dialog acts; only its
      Request acts are used.
    * ``DST`` is a neural model that generates a value for every slot in
      ``ontology.all_info_slots`` from the tokenized dialog history.
    * ``RulePolicy`` chooses the system dialog acts from ``self.state``.
    * ``TemplateNLG`` (system side) renders those acts as text.

    The DST model is placed on the GPU, so CUDA is required.
    """

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Prepare directories, load all pipeline components and start a session.

        Args:
            model_file: Currently unused. The DST checkpoint path is
                hard-coded below; the parameter is kept for compatibility.
            name: Agent name forwarded to ``Agent.__init__``.
        """
        super(Dialog, self).__init__(name=name)

        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            # makedirs also creates missing parents; os.mkdir raised
            # FileNotFoundError when DEFAULT_DIRECTORY/multiwoz was absent.
            os.makedirs(data_dir, exist_ok=True)
            # NOTE(review): the URL is only printed; nothing is downloaded here.
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            # NOTE(review): as above, the trained model is not actually fetched.
            print('down load model from', DEFAULT_MODEL_URL)

        # NOTE(review): Config().parser is presumably an argparse parser; if so,
        # parse_args() with no arguments reads sys.argv, so the host process's
        # command line must be acceptable to it.
        config = Config().parser.parse_args()

        # The asset and checkpoint paths below are relative to the current
        # working directory, not to DEFAULT_DIRECTORY.
        with open("assets/never_split.txt", encoding="utf-8") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this class; it must be
        # provided at module scope (not visible here) to pick the target GPU.
        ckpt = torch.load(
            "save/model_Sun_Jun_21_07:08:48_2020.pt",
            map_location=lambda storage, loc: storage.cuda(local_rank),
        )
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # Renames ontology slot names to the keys used in the belief state.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy",
        }

    def init_session(self):
        """Reset the NLU/policy/NLG sessions, the dialog history and the state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        # [speaker, utterance] pairs; speaker is "user" or "sys".
        self.history = []
        self.state = default_state()

    def response(self, user):
        """Produce the system reply to one user utterance.

        Args:
            user: The user's utterance text.

        Returns:
            The system utterance produced by the NLG. It is also appended to
            ``self.history``.

        Side effects: updates ``self.history``, ``self.input_action``,
        ``self.output_action`` and ``self.state`` (its ``request_state``,
        ``belief_state`` and ``user_action`` entries).
        """
        self.history.append(["user", user])
        user_action = []

        # NLU on the new utterance, with the earlier utterances as context.
        self.input_action = deepcopy(
            self.nlu.predict(user, context=[x[1] for x in self.history[:-1]])
        )
        # Only Request acts come from the NLU; informed values come from the DST.
        for act in self.input_action:
            intent, domain, slot, _ = act
            if intent == "Request":
                user_action.append(act)
                requested = self.state["request_state"]
                if not requested.get(domain):
                    requested[domain] = {}
                requested[domain].setdefault(slot, 0)

        # The DST reads the whole history as one string, truncated to its last
        # MAX_CONTEXT_LENGTH characters before tokenization.
        context = " ".join(utterance for _, utterance in self.history)
        context = context[-MAX_CONTEXT_LENGTH:]
        context_ids = self.tokenizer.encode(context)
        context_ids = torch.tensor(context_ids, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        # Inference only: no_grad keeps autograd from saving activations on
        # every turn.
        with torch.no_grad():
            belief_gen = self.dst_(None, context_ids, 0, test=True)[0]  # [slots, len]

        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            # Drop the trailing <EOS> token (assumes every generated sequence ends with one).
            value = self.tokenizer.decode(belief_gen[slot_idx][:-1])
            if value == "none":
                continue
            domain_state = self.state["belief_state"][domain]
            # A "book" slot takes precedence over a "semi" slot with the same name.
            for section in ("book", "semi"):
                if slot in domain_state[section]:
                    # An Inform act is emitted only when the slot was empty.
                    # NOTE(review): a changed value is stored without a new
                    # Inform act — confirm this is intended.
                    if domain_state[section][slot] == "":
                        # NOTE(review): REF_USR_DA is indexed with the lowercase
                        # ontology domain — confirm its keys are lowercase.
                        user_action.append([
                            "Inform",
                            domain.capitalize(),
                            REF_USR_DA[domain].get(slot, slot),
                            value,
                        ])
                    domain_state[section][slot] = value
                    break

        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])
        return model_response
    "taxi": ["car", "arriveBy", "destination", "departure", "leaveAt"],
    "hotel": ["parking", "internet", "postcode", "phone", "address", "Ref", "stars", "type", "area", "pricerange"],
    "train": ["Ref", "leaveAt", "duration", "price", "arriveBy", "people", "trainID", "destination", "departure", "day"],
    "attraction": ["address", "postcode", "price", "phone", "area", "type"],
    "restaurant": ["address", "Ref", "area", "postcode", "food", "phone", "pricerange", "name"],
    "police":[],
    "hospital":["postcode"]
}

with torch.no_grad():
    for batch_idx, batch in t:
        inputs, contexts, context_lengths, dial_ids = reader.make_input(batch)
        batch_size = len(contexts[0])
        turns = len(inputs)

        nlu.init_session()
        dst.init_session()

        for turn_idx in range(turns):
            context_len = contexts[turn_idx].size(1)
            input_action = nlu.predict(inputs[turn_idx]["usr"][0], inputs[turn_idx]["context"][0])
            dst.state['user_action'] = input_action
            state = dst.update(input_action)
            belief = state["belief_state"]
            belief_label = inputs[turn_idx]["belief"][0]

            # joint_acc_ = 1
            # for slot_idx, value in enumerate(belief_label):
            #     slot = ontology.all_info_slots[slot_idx]
            #     domain, slot = slot.split("-")
            #     slot = mapping.get(slot, slot)