} } # # user_goal = goal goal = Goal(goal_generator) goal.set_user_goal(user_goal) # # user_policy.init_session(ini_goal=goal) # sys_policy.init_session() # # goal = user_policy.get_goal() # # pprint(goal) sys_response = '' # sess.init_session(ini_goal=goal) user_policy.init_session(ini_goal=goal) print('init goal:') # pprint(user_policy.get_goal()) pprint(user_agent.policy.get_goal()) # pprint(sess.evaluator.goal) # print('-' * 50) # for i in range(20): # sys_response, user_response, session_over, reward = sess.next_turn(sys_response) # print('user:'******'sys:', sys_response) # print() # if session_over is True: # break # print('task success:', sess.evaluator.task_success()) # print('book rate:', sess.evaluator.book_rate()) # print('inform precision/recall/f1:', sess.evaluator.inform_F1())
class Dialog(Agent):
    """Pipeline system agent built from BERTNLU, a DST model, RulePolicy and TemplateNLG.

    Per turn (see ``response``):
      1. the NLU output is scanned for ``Request`` acts, which are copied into
         the user action and registered in ``state["request_state"]``;
      2. the whole dialogue history is encoded and fed to the DST model, whose
         per-slot output is written into ``state["belief_state"]``;
      3. the rule policy picks system acts from the state and the template NLG
         turns them into text.

    Attributes set here:
      tokenizer    -- BertTokenizer built from ``assets/vocab.txt``.
      nlu          -- BERTNLU instance.
      dst_         -- DST model, moved to CUDA and put in eval mode.
      policy       -- RulePolicy instance.
      nlg          -- TemplateNLG(is_user=False).
      slot_mapping -- renames ontology slot names to belief-state keys.
      history      -- list of ``[speaker, utterance]`` pairs (see init_session).
      state        -- dialogue state dict from ``default_state()``.
    """

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        """Create working directories, load tokenizer/models and start a session.

        Args:
            model_file: currently unused by this constructor; kept so existing
                callers that pass it keep working.
            name: agent name forwarded to ``Agent.__init__``.

        Note: asset and checkpoint paths are relative to the current working
        directory, and the DST model is placed on a CUDA device.
        """
        super(Dialog, self).__init__(name=name)

        # makedirs(exist_ok=True) instead of mkdir: the targets are nested
        # ('multiwoz/data'), so mkdir fails when 'multiwoz' itself is missing,
        # and it also raised if another process created the directory between
        # the exists() check and the call.
        data_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')
        if not os.path.exists(data_dir):
            os.makedirs(data_dir, exist_ok=True)
            ### download multiwoz data
            print('down load data from', DEFAULT_ARCHIVE_FILE_URL)

        save_dir = os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
            ### download trained model
            print('down load model from', DEFAULT_MODEL_URL)

        # NOTE(review): parse_args() is called without arguments; if this is an
        # argparse parser it reads the host process' command line -- confirm
        # that is intended when Dialog is embedded in another program.
        config = Config()
        parser = config.parser
        config = parser.parse_args()

        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)

        self.nlu = BERTNLU()

        self.dst_ = DST(config).cuda()
        # `local_rank` is not defined in this class; it must be provided at
        # module level by the surrounding file.
        ckpt = torch.load(
            "save/model_Sun_Jun_21_07:08:48_2020.pt",
            map_location=lambda storage, loc: storage.cuda(local_rank),
        )
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()

        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)

        # Ontology slot name -> key used inside the belief state.
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }

        self.init_session()

    def init_session(self):
        """Reset NLU, policy and NLG and start from an empty history/state."""
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def _store_belief_value(self, domain, slot, value, user_action):
        """Write one DST-predicted slot value into the belief state.

        The slot is looked up in the domain's "book" section first, then in
        "semi"; only the first section containing it is touched. When the slot
        was empty before, an ``Inform`` act is appended to ``user_action``.
        Slots present in neither section are ignored.
        """
        domain_state = self.state["belief_state"][domain]
        for section in ("book", "semi"):
            slots = domain_state[section]
            if slot not in slots:
                continue
            if slots[slot] == "":
                # NOTE(review): REF_USR_DA is indexed with the ontology's
                # domain string while the act uses domain.capitalize() --
                # confirm the table is keyed that way.
                user_action.append(
                    ["Inform", domain.capitalize(), REF_USR_DA[domain].get(slot, slot), value]
                )
            # NOTE(review): the original file's indentation was lost; this
            # write is placed outside the `== ""` guard so the state always
            # follows the DST output. Confirm against the upstream source.
            slots[slot] = value
            return

    def response(self, user):
        """Consume one user utterance and return the system's text reply.

        Args:
            user: the user's utterance (text).

        Returns:
            The response produced by the NLG for the policy's chosen acts.

        Side effects: appends both turns to ``self.history``; updates
        ``self.state`` (request state, belief state, user action); stores
        ``self.input_action`` and ``self.output_action``.
        """
        self.history.append(["user", user])
        user_action = []

        # The NLU gets the previous utterances (current one excluded) as context.
        previous_utterances = [turn[1] for turn in self.history[:-1]]
        self.input_action = deepcopy(self.nlu.predict(user, context=previous_utterances))

        # Only Request acts are taken from the NLU; Inform acts are derived
        # from the DST output below.
        request_state = self.state["request_state"]
        for act in self.input_action:
            intent, domain, slot, value = act
            if intent != "Request":
                continue
            user_action.append(act)
            if not request_state.get(domain):
                request_state[domain] = {}
            if slot not in request_state[domain]:
                request_state[domain][slot] = 0

        # Encode the full history; the string is cut to its last
        # MAX_CONTEXT_LENGTH characters before tokenisation.
        context = " ".join(turn[1] for turn in self.history)
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]

        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            # maxsplit=1 so a hyphen inside the slot name cannot break unpacking.
            domain, slot = domain_slot.split("-", 1)
            slot = self.slot_mapping.get(slot, slot)
            value = belief_gen[slot_idx][:-1]  # remove <EOS>
            value = self.tokenizer.decode(value)
            if value != "none":
                self._store_belief_value(domain, slot, value, user_action)

        self.state["user_action"] = user_action

        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])

        return model_response