import os

from datasets import Dataset
from transformers import Trainer, TrainingArguments

from convlab2.nlg.template.multiwoz import TemplateNLG

# `root_dir` is assumed to be defined at module level (the ConvLab-2 project root).


def fine_tune(pos_action, neg_action, tokenizer, model):
    """Fine-tune a BERT sentence-pair classifier to separate coherent (positive)
    from incoherent (negative) user/system turns, rendered to text with TemplateNLG."""
    nlg_usr = TemplateNLG(is_user=True)
    nlg_sys = TemplateNLG(is_user=False)
    pos_train_usr_utter = []
    pos_train_sys_utter = []
    neg_train_usr_utter = []
    neg_train_sys_utter = []
    # turn[0] holds the user dialog acts, turn[1] the system dialog acts;
    # skip turns where either side is empty
    for turn in pos_action:
        if turn[0] != [] and turn[1] != []:
            s_u = nlg_usr.generate(turn[0])
            s_a = nlg_sys.generate(turn[1])
            pos_train_usr_utter.append(s_u)
            pos_train_sys_utter.append(s_a)
    for turn in neg_action:
        if turn[0] != [] and turn[1] != []:
            s_u = nlg_usr.generate(turn[0])
            s_a = nlg_sys.generate(turn[1])
            neg_train_usr_utter.append(s_u)
            neg_train_sys_utter.append(s_a)
    train_usr_utter = pos_train_usr_utter + neg_train_usr_utter
    train_sys_utter = pos_train_sys_utter + neg_train_sys_utter
    # encode each (user utterance, system utterance) pair as one input sequence
    train_encoding = tokenizer(train_usr_utter, train_sys_utter,
                               padding=True, truncation=True, max_length=80)
    train_encoding['label'] = ([1] * len(pos_train_usr_utter)
                               + [0] * len(neg_train_usr_utter))
    train_dataset = Dataset.from_dict(train_encoding)
    train_dataset.set_format('torch',
                             columns=['input_ids', 'attention_mask', 'label'])
    save_dir = os.path.join(root_dir,
                            'convlab2/policy/dqn/NLE/save/script_fine_tune')
    log_dir = os.path.join(root_dir,
                           'convlab2/policy/dqn/NLE/save/script_fine_tune/logs')
    training_args = TrainingArguments(
        output_dir=save_dir,
        num_train_epochs=2,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=128,
        warmup_steps=500,
        weight_decay=0.01,
        # pre-v4 transformers argument; newer releases use `evaluation_strategy`
        evaluate_during_training=False,
        logging_dir=log_dir,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model(os.path.join(save_dir, 'fine_tune_checkpoint'))
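# A minimal usage sketch for fine_tune (not part of the original code). It
# assumes ConvLab-2-style dialog acts of the form [intent, domain, slot, value]
# and that each element of pos_action/neg_action is a [user_acts, system_acts]
# pair, as implied by the loops above; the model name and the toy acts below
# are illustrative only, and `root_dir` must already be defined.
if __name__ == '__main__':
    from transformers import BertForSequenceClassification, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=2)

    # one coherent (positive) and one incoherent (negative) user/system turn
    pos = [[[['Inform', 'Hotel', 'Area', 'east']],
            [['Request', 'Hotel', 'Price', '?']]]]
    neg = [[[['Inform', 'Hotel', 'Area', 'east']],
            [['Inform', 'Train', 'Day', 'monday']]]]
    fine_tune(pos, neg, tokenizer, model)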
import os
from copy import deepcopy

import torch
from transformers import BertTokenizer

from convlab2.dialog_agent import Agent
from convlab2.nlu.jointBERT.multiwoz import BERTNLU
from convlab2.policy.rule.multiwoz import RulePolicy
from convlab2.nlg.template.multiwoz import TemplateNLG
from convlab2.util.multiwoz.state import default_state
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_USR_DA

# Config, DST, ontology, MAX_CONTEXT_LENGTH, local_rank and the DEFAULT_*
# constants are assumed to be provided by this module's own package.


class Dialog(Agent):
    """End-to-end pipeline agent: BERTNLU -> generative DST -> RulePolicy -> TemplateNLG."""

    def __init__(self, model_file=DEFAULT_MODEL_URL, name="Dialog"):
        super(Dialog, self).__init__(name=name)
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data')):
            os.mkdir(os.path.join(DEFAULT_DIRECTORY, 'multiwoz/data'))
            # download MultiWOZ data
            print('download data from', DEFAULT_ARCHIVE_FILE_URL)
        if not os.path.exists(os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save')):
            os.mkdir(os.path.join(DEFAULT_DIRECTORY, 'multiwoz/save'))
            # download trained model
            print('download model from', DEFAULT_MODEL_URL)
        model_path = ""
        config = Config()
        parser = config.parser
        config = parser.parse_args()
        with open("assets/never_split.txt") as f:
            never_split = f.read().split("\n")
        self.tokenizer = BertTokenizer("assets/vocab.txt", never_split=never_split)
        self.nlu = BERTNLU()
        self.dst_ = DST(config).cuda()
        ckpt = torch.load("save/model_Sun_Jun_21_07:08:48_2020.pt",
                          map_location=lambda storage, loc: storage.cuda(local_rank))
        self.dst_.load_state_dict(ckpt["model"])
        self.dst_.eval()
        self.policy = RulePolicy()
        self.nlg = TemplateNLG(is_user=False)
        self.init_session()
        # DST ontology slot names -> MultiWOZ belief-state slot names
        self.slot_mapping = {
            "leave": "leaveAt",
            "arrive": "arriveBy"
        }

    def init_session(self):
        self.nlu.init_session()
        self.policy.init_session()
        self.nlg.init_session()
        self.history = []
        self.state = default_state()

    def response(self, user):
        self.history.append(["user", user])
        user_action = []
        self.input_action = self.nlu.predict(
            user, context=[x[1] for x in self.history[:-1]])
        self.input_action = deepcopy(self.input_action)
        # requests are taken directly from the NLU output
        for act in self.input_action:
            intent, domain, slot, value = act
            if intent == "Request":
                user_action.append(act)
                if not self.state["request_state"].get(domain):
                    self.state["request_state"][domain] = {}
                if slot not in self.state["request_state"][domain]:
                    self.state['request_state'][domain][slot] = 0
        # informed slots are decoded by the generative DST from the full history
        context = " ".join([utterance[1] for utterance in self.history])
        context = context[-MAX_CONTEXT_LENGTH:]
        context = self.tokenizer.encode(context)
        context = torch.tensor(context, dtype=torch.int64).unsqueeze(dim=0).cuda()  # [1, len]
        belief_gen = self.dst_(None, context, 0, test=True)[0]  # [slots, len]
        for slot_idx, domain_slot in enumerate(ontology.all_info_slots):
            domain, slot = domain_slot.split("-")
            slot = self.slot_mapping.get(slot, slot)
            value = belief_gen[slot_idx][:-1]  # remove <EOS>
            value = self.tokenizer.decode(value)
            if value != "none":
                if slot in self.state["belief_state"][domain]["book"].keys():
                    if self.state["belief_state"][domain]["book"][slot] == "":
                        action = ["Inform", domain.capitalize(),
                                  REF_USR_DA[domain].get(slot, slot), value]
                        user_action.append(action)
                        self.state["belief_state"][domain]["book"][slot] = value
                elif slot in self.state["belief_state"][domain]["semi"].keys():
                    if self.state["belief_state"][domain]["semi"][slot] == "":
                        action = ["Inform", domain.capitalize(),
                                  REF_USR_DA[domain].get(slot, slot), value]
                        user_action.append(action)
                        self.state["belief_state"][domain]["semi"][slot] = value
        self.state["user_action"] = user_action
        self.output_action = deepcopy(self.policy.predict(self.state))
        model_response = self.nlg.generate(self.output_action)
        self.history.append(["sys", model_response])
        return model_response
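# A minimal usage sketch for the Dialog agent (not part of the original code).
# It assumes a CUDA device is available and that the assets/ files and the DST
# checkpoint referenced in __init__ exist locally; the utterances are illustrative.
if __name__ == '__main__':
    agent = Dialog()
    print(agent.response('I am looking for a cheap hotel in the east part of town.'))
    print(agent.response('What is its phone number?'))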