class Dialog(Agent):
    """Dialogue agent: BERTNLU + generative DST + RulePolicy + TemplateNLG.

    ``response`` takes a user utterance and builds the user acts for this turn:
    Request acts come from the NLU, and Inform acts are derived from the DST's
    predicted belief. It then runs the rule-based policy on the resulting state
    and renders the system acts with template NLG.
    """

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Build the NLU/DST/policy/NLG pipeline and start a fresh session.

        Args:
            model_file: kept for interface compatibility. NOTE(review): unused;
                the DST checkpoint path below is hard-coded.
            name: agent name forwarded to ``Agent.__init__``.
        """
        super(Dialog, self).__init__(name=name)

        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            # makedirs also creates a missing 'multiwoz' parent, where os.mkdir
            # would raise FileNotFoundError.
            os.makedirs(data_dir, exist_ok=True)
            # NOTE(review): this only prints the URL; nothing is downloaded here.
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)
        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            # NOTE(review): this only prints the URL; nothing is downloaded here.
            print('down load model from', DEFAULT_MODEL_URL)

        # NOTE(review): assuming Config().parser is an argparse parser,
        # parse_args() with no argument list parses sys.argv. The host
        # process's command line must therefore be acceptable to it.
        config = Config().parser.parse_args()

        with open("assets/never_split.txt", encoding="utf-8") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()

        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this class; presumably a
        # module-level global -- confirm. The map_location callable moves every
        # loaded storage onto that CUDA device.
        ckpt = torch.load(
            "save/model_Sun_Jun_21_07:08:48_2020.pt",
            map_location=lambda storage, loc: storage.cuda(local_rank),
        )
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()

        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # Ontology slot names that must be renamed to match belief-state keys.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy",
        }

    def init_session(self):
        """Reset the NLU, policy and NLG sessions, the history and the state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def response(self, user):
        """Return the system utterance for the user utterance ``user``.

        Side effects: appends both turns to ``self.history``. Updates
        ``self.state`` (request_state, belief_state, user_action). Sets
        ``self.input_action`` and ``self.output_action``.
        """
        self.history.append(["user", user])
        # The NLU context is every earlier utterance, excluding the one just added.
        previous_utterances = [turn[1] for turn in self.history[:-1]]
        self.input_action = deepcopy(
            self.nlu.predict(user, context=previous_utterances))

        user_action = []
        # Request acts go first, then Inform acts derived from the DST belief.
        self._track_requests(self.input_action, user_action)
        belief_gen = self._generate_belief()
        self._update_belief_state(belief_gen, user_action)
        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])
        return model_response

    def _track_requests(self, acts, user_action):
        """Copy NLU Request acts into ``user_action`` and register them in request_state.

        Each act is unpacked as ``(intent, domain, slot, value)``. Requested
        slots are recorded with value 0 unless already present.
        """
        request_state = self.state["request_state"]
        for act in acts:
            intent, domain, slot, _value = act
            if intent != "Request":
                continue
            user_action.append(act)
            if not request_state.get(domain):
                request_state[domain] = {}
            if slot not in request_state[domain]:
                request_state[domain][slot] = 0

    def _generate_belief(self):
        """Run the DST decoder over the dialogue history; return its first output.

        The comment below follows the original shape annotation (one row of
        generated token ids per slot); not verified here.
        """
        context = " ".join(turn[1] for turn in self.history)
        # NOTE(review): this keeps the last MAX_CONTEXT_LENGTH *characters* of
        # the joined string, so it may cut a word; confirm the constant is not
        # meant as a token count.
        context = context[-MAX_CONTEXT_LENGTH:]
        token_ids = self.tokenizer.encode(context)
        token_ids = torch.tensor(token_ids, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]
        # Inference only: do not record an autograd graph on every turn.
        with torch.no_grad():
            return self.dst_(None, token_ids, 0, test=True)[0]  # [slots, len]

    def _update_belief_state(self, belief_gen, user_action):
        """Write DST predictions into belief_state and add Inform acts for changed slots.

        ``belief_gen[i]`` is decoded as the value of ``ontology.all_info_slots[i]``.
        Values "none" and "" are skipped. A slot present in "book" takes
        precedence over "semi". An Inform act is emitted, and the slot
        overwritten, whenever the prediction differs from the stored value.
        This covers both newly filled slots and slots the user changed.
        """
        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            # The last generated token is treated as <EOS> and dropped.
            value = self.tokenizer.decode(belief_gen[slot_idx][:-1])
            if value in ("", "none"):
                continue
            domain_belief = self.state["belief_state"][domain]
            for section in ("book", "semi"):
                if slot not in domain_belief[section]:
                    continue
                if domain_belief[section][slot] != value:
                    user_action.append([
                        "Inform",
                        domain.capitalize(),
                        self._usr_da_slot(domain, slot),
                        value,
                    ])
                    domain_belief[section][slot] = value
                break

    @staticmethod
    def _usr_da_slot(domain, slot):
        """Map a belief-state slot name to its user dialogue-act slot name.

        The emitted act uses ``domain.capitalize()`` while ``domain`` is the
        belief-state key. The table is therefore looked up under the capitalised
        key first and the raw key second. NOTE(review): confirm which casing
        REF_USR_DA uses. Unknown domains or slots fall back to ``slot``.
        """
        table = REF_USR_DA.get(domain.capitalize())
        if table is None:
            table = REF_USR_DA.get(domain, {})
        return table.get(slot, slot)
"taxi": ["car", "arriveBy", "destination", "departure", "leaveAt"], "hotel": ["parking", "internet", "postcode", "phone", "address", "Ref", "stars", "type", "area", "pricerange"], "train": ["Ref", "leaveAt", "duration", "price", "arriveBy", "people", "trainID", "destination", "departure", "day"], "attraction": ["address", "postcode", "price", "phone", "area", "type"], "restaurant": ["address", "Ref", "area", "postcode", "food", "phone", "pricerange", "name"], "police":[], "hospital":["postcode"] } with torch.no_grad(): for batch_idx, batch in t: inputs, contexts, context_lengths, dial_ids = reader.make_input(batch) batch_size = len(contexts[0]) turns = len(inputs) nlu.init_session() dst.init_session() for turn_idx in range(turns): context_len = contexts[turn_idx].size(1) input_action = nlu.predict(inputs[turn_idx]["usr"][0], inputs[turn_idx]["context"][0]) dst.state['user_action'] = input_action state = dst.update(input_action) belief = state["belief_state"] belief_label = inputs[turn_idx]["belief"][0] # joint_acc_ = 1 # for slot_idx, value in enumerate(belief_label): # slot = ontology.all_info_slots[slot_idx] # domain, slot = slot.split("-") # slot = mapping.get(slot, slot)