def get_delex_dict(data_path):
    """Build the substitution tables used to delexicalise system responses.

    For every domain in ``ontology.all_domains`` the database file
    ``<data_path>/<domain>_db.json`` is read and each cleaned slot value is
    mapped, padded with single spaces, to a placeholder such as
    ``" [restaurant_name] "``. The taxi domain is not read from a database;
    a fixed list of car brands is mapped to ``" [taxi_car] "`` instead.
    Booking reference numbers returned by ``get_references(data_path)`` are
    mapped to ``" [<domain>_reference] "``.

    Train arrival/departure times go into a separate table and are mapped to
    the generic ``" [train_time] "`` token; the decision between arrival and
    departure is made later, when the data is created.

    Args:
        data_path: directory containing the ``<domain>_db.json`` files.

    Returns:
        (delex_dict, train_time_dict): both map a raw string to its
        placeholder string.
    """
    taxi_cars = ["toyota", "skoda", "bmw", "honda", "ford", "audi", "lexus", "volvo", "volkswagen", "tesla"]
    star_rules = ["{} star", "{}-star", "rating of {}", "star of {}"]

    delex_dict = {}
    train_time_dict = {}
    for domain in ontology.all_domains:
        if domain == "taxi":
            for car in taxi_cars:
                delex_dict[" {} ".format(car)] = " [taxi_car] "
            continue

        # Context manager so the database file handle is closed promptly.
        with open(os.path.join(data_path, "{}_db.json".format(domain)), "r") as f:
            db = json.load(f)

        for entry in db:
            for slot, value in entry.items():
                # Skip every "location" field, and the "price" field only for hotels.
                if slot == "location" or (slot == "price" and domain == "hotel"):
                    continue
                slot, value = clean_data.clean_slot_values(domain, slot, value)
                if slot == "price":
                    if value in ["?", "free"]:
                        continue
                    delex_dict[" {} ".format(value)] = " [price] "
                elif slot in ["address", "food", "postcode", "duration", "name"]:
                    delex_dict[" {} ".format(value)] = " [{}_{}] ".format(domain, slot)
                elif slot == "stars":
                    # Star ratings are matched inside short phrases, not as bare values.
                    for rule in star_rules:
                        delex_dict[rule.format(value)] = rule.format("[hotel_stars]")
                elif slot == "id" and domain == "train":
                    delex_dict[" {} ".format(value)] = " [train_id] "
                elif slot in ["arrive", "leave"]:
                    # Replaced by [train_arrive] or [train_leave] after delexicalization,
                    # during data creation.
                    train_time_dict[" {} ".format(value)] = " [train_time] "
                elif slot == "pricerange":
                    if value in ["?", "free"]:
                        continue
                    delex_dict[" {} ".format(value)] = " [pricerange] "
                elif slot == "type" and domain in ["hotel", "attraction"]:
                    delex_dict[" {} ".format(value)] = " [{}_type] ".format(domain)
                elif slot == "area":
                    delex_dict[" {} ".format(value)] = " [area] "
                elif slot == "phone":
                    delex_dict[" {} ".format(value)] = " [phone] "
                elif slot == "department":
                    # NOTE(review): department values map to the phone placeholder;
                    # confirm this is intended and not a copy-paste slip.
                    delex_dict[" {} ".format(value)] = " [hospital_phone] "

    for reference in get_references(data_path):
        domain, num = reference.split("-")
        delex_dict[" {} ".format(num)] = " [{}_reference] ".format(domain)
    return delex_dict, train_time_dict
def parse_goal(self, dial_id):
    """Extract the inform and request goals of one dialogue for evaluation.

    The dialogue is looked up in the dev, test and train splits, in that order.

    Args:
        dial_id: dialogue id as listed in ``self.dev_list``, ``self.test_list``
            or ``self.train_list``.

    Returns:
        (inform_goal, request_goal):
        ``inform_goal`` maps domain -> {slot: value} built from the cleaned
        "info" constraints. ``request_goal`` maps domain -> list of requested
        slots, restricted to ``self.requestables``, with "reference" appended
        when the domain has a "book" goal. Domains with nothing to report are
        omitted from both dicts.

    Raises:
        KeyError: if ``dial_id`` is in none of the split lists.
    """
    if dial_id in self.dev_list:
        goal = self.dev[dial_id]["goal"]
    elif dial_id in self.test_list:
        goal = self.test[dial_id]["goal"]
    elif dial_id in self.train_list:
        goal = self.train[dial_id]["goal"]
    else:
        raise KeyError("dialogue {} is not in the dev/test/train lists".format(dial_id))

    inform_goal = {}
    request_goal = {}
    for domain, goals in goal.items():
        informed = {}
        for slot, value in goals.get("info", {}).items():
            slot, value = clean_slot_values(domain, slot, value)
            informed[slot] = value
        if informed:
            inform_goal[domain] = informed

        # Handle "reqt" before "book" regardless of their order in the goal dict,
        # so a "book" entry listed first still yields a "reference" request.
        requested = []
        for slot in goals.get("reqt", []):
            slot = slot.lower()
            slot = ontology.normlize_slot_names.get(slot) or slot
            if slot in self.requestables:  # evaluate only the configured request slots
                requested.append(slot)
        if "book" in goals:
            requested.append("reference")
        if requested:
            request_goal[domain] = requested
    return inform_goal, request_goal
def create_data(self):
    """Convert the raw dialogues into train/dev/test splits with delexicalised responses.

    Each dialogue in ``self.data`` (id truncated at the first ".") becomes::

        {"goal": {domain: goal, ...},
         "log": [{"turn_num", "user", "response", "response_delex",
                  "belief", "gate", "action"}, ...]}

    Only non-empty goals for domains in ``ontology.all_domains`` are kept.
    Dialogues with no such domain, and the hard-coded ids in the ignore set,
    are dropped. Requestable slot names in the goal are normalised in place.

    The value ontology ``<self.data_path>/ontology.json`` is also cleaned and
    written to ``<self.data_path>/ontology_processed.json``.

    Returns:
        (train, dev, test): dicts mapping dialogue id to processed dialogue,
        split according to ``self.train_list``, ``self.dev_list`` and
        ``self.test_list``.
    """
    # If any of these appears in a response, train times are delexicalised too.
    train_cues = (
        "train", "arrive", "arrives", "arrived", "arriving", "arrival", "destination", "reach",
        "leave", "leaves", "leaving", "leaved", "depart", "departing", "departs", "departure",
        "[train_",
    )
    # The nearest preceding cue word decides whether a [train_time] is an arrival or a departure.
    arrive_cues = {"arrive", "arrives", "arrived", "arriving", "arrival", "destination", "reach",
                   "by", "before", "have", "to"}
    leave_cues = {"leave", "leaves", "leaving", "leaved", "depart", "departing", "departs",
                  "departure", "from", "after", "earlier", "there"}

    data = {}
    train = {}
    dev = {}
    test = {}
    ignored = {"SNG1213", "PMUL0382", "PMUL0237"}
    logger.info("Processing data...")
    for dial_id, dial in tqdm(self.data.items()):
        dial_id = dial_id.split(".")[0]
        if dial_id in ignored:
            continue

        goal = {}
        dial_domains = []
        for key, value in dial["goal"].items():  # process user's goal
            if key in ontology.all_domains and value != {}:
                if value.get("reqt"):  # normalize requestable slot names
                    for idx, slot in enumerate(value["reqt"]):
                        if ontology.normlize_slot_names.get(slot):
                            value["reqt"][idx] = ontology.normlize_slot_names[slot]
                goal[key] = value
                dial_domains.append(key)
        if not dial_domains:  # ignore police and hospital
            ignored.add(dial_id)
            continue

        dialogue = {"goal": goal, "log": []}
        acts = self.acts[dial_id]
        turn = {}
        for turn_num, turn_dial in enumerate(dial["log"]):
            meta_data = turn_dial["metadata"]
            if meta_data == {}:  # user turn
                turn["turn_num"] = turn_num // 2
                turn["user"] = clean_text(turn_dial["text"])
                continue

            # system turn
            response = clean_text(turn_dial["text"])
            turn["response"] = response

            response_ = response
            for k, v in self.delex_dict.items():  # delexicalize values
                response_ = response_.replace(k, v)
            if any(cue in response_ for cue in train_cues):
                for k, v in self.train_time_dict.items():  # delexicalize train times
                    response_ = response_.replace(k, v)
            response_ = re.sub(r"(\d\s?){11}", "[phone]", response_)  # delexicalize phone number
            # Replace each [train_time] with [train_arrive] or [train_leave] by rule.
            while "[train_time]" in response_:
                tokens = response_.split()
                idx = tokens.index("[train_time]")
                replacement = "[train_leave]"  # fallback when no cue word precedes the time
                for token in reversed(tokens[:idx]):
                    if token in arrive_cues:
                        replacement = "[train_arrive]"
                        break
                    if token in leave_cues:
                        replacement = "[train_leave]"
                        break
                tokens[idx] = replacement
                response_ = " ".join(tokens)
            turn["response_delex"] = response_

            belief = {}
            gate = {}
            for domain in dial_domains:  # active domains of dialogue
                for part in ("book", "semi"):
                    for slot, value in meta_data[domain][part].items():
                        if part == "book" and slot == "booked":
                            continue
                        slot, value = clean_slot_values(domain, slot, value)
                        if value == "":
                            continue
                        key = "{}-{}".format(domain, slot)
                        belief[key] = value
                        # NOTE(review): confirm clean_slot_values emits "dontcare" here and that
                        # ontology.gate_dict has that key.
                        gate[key] = ontology.gate_dict[value if value == "dontcare" else "prediction"]
            turn["belief"] = belief
            turn["gate"] = gate

            act = {}
            turn_acts = acts.get(str(turn["turn_num"] + 1))
            if turn_acts and not isinstance(turn_acts, str):  # mapping system action
                for domain_act, slots in turn_acts.items():
                    act_temp = []
                    for slot in slots:  # slot: [slot, value]
                        slot_, value_ = clean_slot_values(domain_act.split("-")[0], slot[0], slot[1])
                        if slot_ == "none" or value_ in ["?", "none"]:  # general domain or request slot or parking
                            act_temp.append(slot_)
                        else:
                            act_temp.append("{}-{}".format(slot_, value_))
                    act[domain_act.lower()] = act_temp
            turn["action"] = act

            dialogue["log"].append(turn)
            turn = {}  # clear turn
        data[dial_id] = dialogue
    logger.info("Processing finished.")

    logger.info("Dividing data to train/dev/test...")
    for dial_ids, split in ((self.train_list, train), (self.dev_list, dev), (self.test_list, test)):
        for dial_id in dial_ids:
            dial_id = dial_id.split(".")[0]
            if dial_id not in ignored:
                split[dial_id] = data[dial_id]
    logger.info("Dividing finished.")

    with open(os.path.join(self.data_path, "ontology.json"), "r") as f:
        value_ontology = json.load(f)
    value_ontology_processed = {}
    logger.info("Processing ontology...")
    for domain_slot, values in value_ontology.items():
        # Keys are presumably "<domain>-<part>-<slot>"; only fields 0 and 2 are used.
        fields = domain_slot.split("-")
        domain = fields[0]
        slot = fields[2].lower()
        if ontology.normlize_slot_names.get(slot):
            slot = ontology.normlize_slot_names[slot]
        value_ontology_processed["-".join([domain, slot])] = [
            clean_slot_values(domain, slot, value)[1] for value in values
        ]
    # Written to the same directory ontology.json was read from.
    with open(os.path.join(self.data_path, "ontology_processed.json"), "w") as f:
        json.dump(value_ontology_processed, f, indent=2)
    logger.info("Ontology was processed.")
    return train, dev, test
def train_forward(self, turn_inputs, turn_contexts, turn_context_lengths, action_history, teacher_forcing):
    """Run one training forward pass of belief tracking conditioned on NLU user acts.

    For every example, ``self.nlu.predict`` parses the user utterance. The
    predicted acts update ``self.states`` and are appended to
    ``action_history`` in place. The accumulated act history and the turn
    context are both encoded with ``self.context_encoder``. A sigmoid score of
    their last-layer hidden states then weights the two sources, both in the
    decoder's initial hidden state and in the attention weights passed to
    ``self.belief_decoder``.

    Args:
        turn_inputs: dict with (at least) "usr", "context", "belief" and
            "gate" entries, indexed by batch position.
        turn_contexts: context token ids; dim 0 is the batch.
        turn_context_lengths: per-example context lengths, assumed to be a
            tensor on the same device as ``turn_contexts`` (TODO confirm).
        action_history: per-example lists of [intent, domain, slot, value]
            acts; extended in place.
        teacher_forcing: forwarded to the belief decoder; also selects how
            the value loss is computed.

    Returns:
        (gate_loss, value_loss, acc_belief, belief_gen, action_history):
        ``acc_belief`` is a [batch, slots] tensor that is 1 where both gate
        and value were predicted correctly and 0 otherwise. ``belief_gen``
        holds the generated value token ids per example and slot.
    """
    # Allocate new tensors next to the inputs instead of on the default CUDA device.
    device = turn_contexts.device
    batch_size = turn_contexts.size(0)
    user = turn_inputs["usr"]
    history = turn_inputs["context"]
    bos_idx = self.vocab.word2idx["<bos>"]
    eos_idx = self.vocab.word2idx["<eos>"]
    pad_idx = self.vocab.word2idx["<pad>"]

    for batch_idx in range(batch_size):
        state = self.states[batch_idx]
        state["history"].append(["usr", user[batch_idx]])
        input_action = self.nlu.predict(user[batch_idx], history[batch_idx])  # [[intent, domain, slot, value], ...]
        state["user_action"] = input_action
        action_history[batch_idx] += input_action
        # update request state
        request_state = state["request_state"]
        for intent, domain, slot, value in input_action:
            slot, _ = clean_slot_values(domain, slot, value)
            domain = domain.lower()
            if intent == "Request":
                if not request_state.get(domain):
                    request_state[domain] = {}
                request_state[domain][slot] = 0

    # Encode the accumulated NLU act history as one <bos> ... <eos> sequence per example.
    encoded_action_history = []
    for actions in action_history:
        encoded_actions = [bos_idx]
        for intent, domain, slot, value in actions:
            slot = ontology.normlize_slot_names.get(slot, slot)
            # Drop the first and last ids (presumably the <bos>/<eos> added by encode).
            encoded_actions += self.vocab.encode(" ".join([intent, domain, slot, value]))[1:-1]
        encoded_actions.append(eos_idx)
        encoded_action_history.append(encoded_actions)
    history_lengths = torch.tensor([len(actions) for actions in encoded_action_history], device=device)
    encoded_action_history_ = torch.zeros(
        (batch_size, history_lengths.max().item()), dtype=torch.int64, device=device)
    for batch_idx, actions in enumerate(encoded_action_history):
        encoded_action_history_[batch_idx, :len(actions)] = torch.tensor(actions)

    action_history_outputs, action_history_hidden = self.context_encoder(
        encoded_action_history_, history_lengths)  # [batch, len, hidden], [layers, batch, hidden]
    encoder_outputs_, encoder_hidden_ = self.context_encoder(
        turn_contexts, turn_context_lengths)  # [batch, len, hidden], [layers, batch, hidden]
    action_score = torch.sigmoid(torch.matmul(
        action_history_hidden[-1].unsqueeze(dim=1), encoder_hidden_[-1].unsqueeze(dim=2)))  # [batch, 1, 1]

    # Per-position attention weight: (1 - score) over context tokens, score over act-history tokens.
    # view(-1) instead of squeeze(): squeeze() turns a batch of one into an unindexable 0-d tensor.
    scores = action_score.view(-1)  # [batch]
    action_score_attention = torch.zeros(
        batch_size, (turn_context_lengths + history_lengths).max().item(), device=device)  # [batch, len]
    for batch_idx, context_len in enumerate(turn_context_lengths):
        history_len = len(encoded_action_history[batch_idx])
        action_score_attention[batch_idx, :context_len] = 1 - scores[batch_idx]
        action_score_attention[batch_idx, context_len:context_len + history_len] = scores[batch_idx]

    # weighted sum & weighted concat
    encoder_outputs, turn_contexts, turn_context_lengths = self.concat(
        encoder_outputs_, action_history_outputs, turn_contexts, turn_context_lengths,
        encoded_action_history_, history_lengths)
    hidden_weight = action_score.transpose(0, 1).contiguous()  # [1, batch, 1]
    encoder_hidden = (1 - hidden_weight) * encoder_hidden_ + hidden_weight * action_history_hidden  # [layers, batch, hidden]
    gate_outputs, all_probs, all_pred_words = self.belief_decoder(
        encoder_outputs, encoder_hidden, turn_contexts, turn_context_lengths,
        turn_inputs["belief"], teacher_forcing, action_score_attention)
    # [batch, slots, gates], [batch, slots, len, vocab], [batch, slots, len]
    gate_preds = gate_outputs.argmax(dim=2)

    # Encoded once; a fresh copy ([:]) is stored per slot so entries never alias.
    none_value = self.vocab.encode("none")[1:]
    dontcare_value = self.vocab.encode("don't care")[1:]
    belief_gen = []  # [batch, slots, len]
    for batch_idx, batch in enumerate(all_pred_words):
        belief_gen_ = []  # [slots, len]
        for slot_idx, pred_words in enumerate(batch):
            gate_pred = gate_preds[batch_idx, slot_idx].item()
            if gate_pred == ontology.gate_dict["none"]:
                belief_gen_.append(none_value[:])
            elif gate_pred == ontology.gate_dict["don't care"]:
                belief_gen_.append(dontcare_value[:])
            else:
                # Keep tokens up to and including the first <eos> (all of them if there is none).
                end = len(pred_words)
                for position, word in enumerate(pred_words):
                    if word == eos_idx:
                        end = position + 1
                        break
                belief_gen_.append(pred_words[:end].tolist())
        belief_gen.append(belief_gen_)

    gate_label = turn_inputs["gate"]
    gate_loss = F.cross_entropy(gate_outputs.view(-1, gate_outputs.size(-1)), gate_label.view(-1))

    num_slots = len(ontology.all_info_slots)
    acc_belief = torch.ones(batch_size, num_slots, device=device)
    acc_belief.masked_fill_(gate_label != gate_preds, value=0)  # fail to predict gate
    value_label = turn_inputs["belief"]
    value_label_lengths = torch.zeros(batch_size, num_slots, dtype=torch.int64, device=device)
    for batch_idx, batch in enumerate(value_label):
        for slot_idx, label_words in enumerate(batch):
            value = label_words[label_words != pad_idx].tolist()  # remove padding
            value_label_lengths[batch_idx, slot_idx] = len(value)
            if value != belief_gen[batch_idx][slot_idx]:
                acc_belief[batch_idx, slot_idx] = 0  # fail to predict value

    if teacher_forcing:
        value_loss = masked_cross_entropy_for_value(all_probs, value_label, value_label_lengths)
    else:
        # Decoded and label lengths can presumably differ without teacher forcing; compare the overlap.
        min_len = min(value_label.size(2), all_probs.size(2))
        value_loss = masked_cross_entropy_for_value(
            all_probs[:, :, :min_len, :].contiguous(),
            value_label[:, :, :min_len].contiguous(),
            value_label_lengths)
    return gate_loss, value_loss, acc_belief, belief_gen, action_history
def create_data(self):
    """Convert the raw dialogues into train/dev/test splits.

    Each dialogue in ``self.data`` (id truncated at the first ".") becomes::

        {"goal": {domain: goal, ...},
         "log": [{"turn_num", "user", "response", "belief", "gate",
                  "action"}, ...]}

    Only non-empty goals for domains in ``ontology.all_domains`` are kept.
    Dialogues with no such domain, and the hard-coded ids in the ignore set,
    are dropped. Requestable slot names in the goal are normalised in place.

    The value ontology ``<self.data_path>/ontology.json`` is also cleaned and
    written to ``<self.data_path>/ontology_processed.json``.

    Returns:
        (train, dev, test): dicts mapping dialogue id to processed dialogue,
        split according to ``self.train_list``, ``self.dev_list`` and
        ``self.test_list``.
    """
    data = {}
    train = {}
    dev = {}
    test = {}
    ignored = {"SNG1213", "PMUL0382", "PMUL0237"}
    logger.info("Processing data...")
    for dial_id, dial in tqdm(self.data.items()):
        dial_id = dial_id.split(".")[0]
        if dial_id in ignored:
            continue

        goal = {}
        dial_domains = []
        for key, value in dial["goal"].items():  # process user's goal
            if key in ontology.all_domains and value != {}:
                if value.get("reqt"):  # normalize requestable slot names
                    for idx, slot in enumerate(value["reqt"]):
                        if ontology.normlize_slot_names.get(slot):
                            value["reqt"][idx] = ontology.normlize_slot_names[slot]
                goal[key] = value
                dial_domains.append(key)
        if not dial_domains:  # ignore police and hospital
            ignored.add(dial_id)
            continue

        dialogue = {"goal": goal, "log": []}
        acts = self.acts[dial_id]
        turn = {}
        for turn_num, turn_dial in enumerate(dial["log"]):
            meta_data = turn_dial["metadata"]
            if meta_data == {}:  # user turn
                turn["turn_num"] = turn_num // 2
                turn["user"] = clean_text(turn_dial["text"])
                continue

            # system turn
            turn["response"] = clean_text(turn_dial["text"])

            belief = {}
            gate = {}
            for domain in dial_domains:  # active domains of dialogue
                for part in ("book", "semi"):
                    for slot, value in meta_data[domain][part].items():
                        if part == "book" and slot == "booked":
                            continue
                        slot, value = clean_slot_values(domain, slot, value)
                        if value == "":
                            continue
                        key = "{}-{}".format(domain, slot)
                        belief[key] = value
                        gate[key] = ontology.gate_dict[value if value == "don't care" else "prediction"]
            turn["belief"] = belief
            turn["gate"] = gate

            act = {}
            turn_acts = acts.get(str(turn["turn_num"] + 1))
            if turn_acts and not isinstance(turn_acts, str):  # mapping system action
                for domain_act, slots in turn_acts.items():
                    act_temp = []
                    for slot in slots:  # slot: [slot, value]
                        slot_, value_ = clean_slot_values(domain_act.split("-")[0], slot[0], slot[1])
                        if slot_ == "none" or value_ in ["?", "none"]:  # general domain or request slot or parking
                            act_temp.append(slot_)
                        else:
                            act_temp.append("{}-{}".format(slot_, value_))
                    act[domain_act.lower()] = act_temp
            turn["action"] = act

            dialogue["log"].append(turn)
            turn = {}  # clear turn
        data[dial_id] = dialogue
    logger.info("Processing finished.")

    logger.info("Dividing data to train/dev/test...")
    for dial_ids, split in ((self.train_list, train), (self.dev_list, dev), (self.test_list, test)):
        for dial_id in dial_ids:
            dial_id = dial_id.split(".")[0]
            if dial_id not in ignored:
                split[dial_id] = data[dial_id]
    logger.info("Dividing finished.")

    with open(os.path.join(self.data_path, "ontology.json"), "r") as f:
        value_ontology = json.load(f)
    value_ontology_processed = {}
    logger.info("Processing ontology...")
    for domain_slot, values in value_ontology.items():
        # Keys are presumably "<domain>-<part>-<slot>"; only fields 0 and 2 are used.
        fields = domain_slot.split("-")
        domain = fields[0]
        slot = fields[2].lower()
        if ontology.normlize_slot_names.get(slot):
            slot = ontology.normlize_slot_names[slot]
        value_ontology_processed["-".join([domain, slot])] = [
            clean_slot_values(domain, slot, value)[1] for value in values
        ]
    # Written to the same directory ontology.json was read from.
    with open(os.path.join(self.data_path, "ontology_processed.json"), "w") as f:
        json.dump(value_ontology_processed, f, indent=2)
    logger.info("Ontology was processed.")
    return train, dev, test