class Dialog(Agent):
    """End-to-end MultiWOZ system agent.

    Pipeline per turn: BERT NLU (for user "Request" acts) -> neural DST over the
    dialog history (for informed slot values) -> rule-based policy -> template NLG.
    """

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Build every pipeline component and load the DST checkpoint.

        Args:
            model_file: Currently unused; kept for interface compatibility.
                NOTE(review): the checkpoint path below is hard-coded instead.
            name: Agent name forwarded to the ``Agent`` base class.
        """
        super(Dialog, self).__init__(name=name)

        # NOTE(review): despite the messages printed below, nothing is actually
        # downloaded here -- only the target directories are created.
        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            # makedirs also creates missing parents, where os.mkdir would raise.
            os.makedirs(data_dir, exist_ok=True)
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)

        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            print('down load model from', DEFAULT_MODEL_URL)

        # NOTE(review): parse_args() reads sys.argv, so building this agent
        # depends on the command line of the hosting process.
        config = Config().parser.parse_args()

        with open("assets/never_split.txt", encoding="utf-8") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)

        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        # NOTE(review): `local_rank` is not defined in this class; it must be a
        # module-level name elsewhere in the file -- confirm.
        ckpt = torch.load(
            "save/model_Sun_Jun_21_07:08:48_2020.pt",
            map_location=lambda storage, loc: storage.cuda(local_rank),
        )
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()

        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()

        # DST slot names that differ from the keys used in the belief state.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy",
        }

    def init_session(self):
        """Reset the NLU/policy/NLG components, the history and the dialog state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def response(self, user):
        """Return the system reply to the user utterance ``user``.

        Side effects: appends both turns to ``self.history`` and updates
        ``self.state`` (request_state, belief_state, user_action), as well as
        ``self.input_action`` / ``self.output_action``.
        """
        self.history.append(["user", user])
        user_action = []

        # NLU sees the previous utterances as context. Only its "Request" acts
        # are used; informed values come from the DST below.
        self.input_action = deepcopy(
            self.nlu.predict(user, context=[x[1] for x in self.history[:-1]])
        )
        for act in self.input_action:
            intent, domain, slot, value = act
            if intent == "Request":
                user_action.append(act)
                if not self.state["request_state"].get(domain):
                    self.state["request_state"][domain] = {}
                self.state["request_state"][domain].setdefault(slot, 0)

        self._track_belief(user_action)

        self.state["user_action"] = user_action
        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])
        return model_response

    def _track_belief(self, user_action):
        """Run the DST over the history and merge its predictions into the belief state.

        Appends an ``Inform`` act to ``user_action`` for every slot that was
        empty before this turn and now receives a value.
        """
        context = " ".join(utterance[1] for utterance in self.history)
        # Keep only the last MAX_CONTEXT_LENGTH characters of the joined history.
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        # Inference only: eval() does not disable autograd, no_grad() does.
        with torch.no_grad():
            belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]

        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            # Drop the final token, assumed to be <EOS> -- TODO confirm for all slots.
            value = self.tokenizer.decode(belief_gen[slot_idx][:-1])
            if value != "none":
                self._set_belief_slot(domain, slot, value, user_action)

    def _set_belief_slot(self, domain, slot, value, user_action):
        """Store ``value`` in the book (preferred) or semi belief slot; emit Inform if it was empty."""
        domain_belief = self.state["belief_state"][domain]
        for section in ("book", "semi"):
            if slot in domain_belief[section]:
                if domain_belief[section][slot] == "":
                    # Accept REF_USR_DA keyed by either the lowercase or the
                    # capitalized domain name (the act itself uses capitalized).
                    slot_names = (REF_USR_DA.get(domain)
                                  or REF_USR_DA.get(domain.capitalize(), {}))
                    user_action.append(
                        ["Inform", domain.capitalize(), slot_names.get(slot, slot), value]
                    )
                domain_belief[section][slot] = value
                return
# Debug walk-through of one user -> system -> user exchange with the simulator.
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
history.append(['user', user_utt])

state = dst.state
state['user_action'] = user_act
dst.update(user_act)

# The policy is still queried so that any internal policy state advances.
sys_act = sys_policy.predict(state)
# DEBUG: override the policy output with a fixed act to probe the user simulator.
sys_act = [['Inform', 'Hotel', 'Post', 'pe296fl']]
print(sys_act)
# Verbalize the act actually sent to the user, and record the *system*
# utterance (previously the user utterance was recorded here by mistake).
sys_utt = sys_nlg.generate(sys_act)
history.append(['sys', sys_utt])

user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
history.append(['user', user_utt])