def main():
    model_config_name = 'policy/mle/train.json'
    common_config_name = 'policy/mle/common.json'

    data_urls = {
        # assumption: the sys vocab is served as sys_da_voc.json; the source
        # pointed both keys at the usr file, which looks like a copy-paste slip
        'sys_da_voc.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/sys_da_voc.json',
        'usr_da_voc.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/usr_da_voc.json'
    }

    # load config
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    model_config_path = os.path.join(get_config_path(), model_config_name)
    common_config = json.load(open(common_config_path))
    model_config = json.load(open(model_config_path))
    model_config.update(common_config)
    model_config['n_gpus'] = torch.cuda.device_count()
    model_config['batch_size'] = max(1, model_config['n_gpus']) * model_config['batch_size']
    model_config['device'] = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_config['data_path'] = os.path.join(get_data_path(), 'crosswoz/policy_mle_data')
    model_config['raw_data_path'] = os.path.join(get_data_path(), 'crosswoz/raw')
    model_config['output_dir'] = os.path.join(root_path, model_config['output_dir'])
    if model_config['load_model_name']:
        model_config['model_path'] = os.path.join(model_config['output_dir'],
                                                  model_config['load_model_name'])
    else:
        model_config['model_path'] = ''
    if not os.path.exists(model_config['data_path']):
        os.makedirs(model_config['data_path'])
    if not os.path.exists(model_config['output_dir']):
        os.makedirs(model_config['output_dir'])

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(model_config['data_path'], data_key)
        file_name = data_key.split('.')[0]
        model_config[file_name] = dst
        if not os.path.exists(dst):
            download_from_url(url, dst)

    print('>>> Train configs:')
    print('\t', model_config)

    set_seed(model_config['random_seed'])

    agent = Trainer(model_config)

    # train
    if model_config['do_train']:
        # resume from the epoch encoded in the checkpoint file name, if any
        start_epoch = 0 if not model_config['model_path'] else int(
            model_config['model_path'].split('-')[2]) + 1
        best = float('inf')
        for epoch in tqdm(range(start_epoch, model_config['num_epochs']), desc='Epoch'):
            agent.imitating(epoch)
            best = agent.imit_eval(epoch, best)

    agent.calc_metrics()
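# A minimal sketch with toy dicts (not the real config files) of the merge
# order used above: dict.update lets common.json values overwrite any key
# that train.json also defines.
model_config = {'batch_size': 32, 'num_epochs': 50}    # from train.json
common_config = {'batch_size': 16, 'random_seed': 42}  # from common.json
model_config.update(common_config)
assert model_config == {'batch_size': 16, 'num_epochs': 50, 'random_seed': 42}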
def preprocess(mode):
    assert mode == "all" or mode == "usr" or mode == "sys"
    # path
    data_key = ["train", "val", "test"]
    data = {}
    for key in data_key:
        # read crosswoz source data from json.zip
        data[key] = read_zipped_json(
            os.path.join(get_data_path(), "crosswoz/raw", key + ".json.zip"),
            key + ".json",
        )
        print("load {}, size {}".format(key, len(data[key])))

    # generate train, val, test datasets
    for key in data_key:
        sessions = []
        for no, sess in data[key].items():
            processed_data = OrderedDict()
            processed_data["sys-usr"] = sess["sys-usr"]
            processed_data["type"] = sess["type"]
            processed_data["task description"] = sess["task description"]
            messages = sess["messages"]
            processed_data["turns"] = [
                OrderedDict(
                    {
                        "role": message["role"],
                        "utterance": message["content"],
                        "dialog_act": message["dialog_act"],
                    }
                )
                for message in messages
            ]
            sessions.append(processed_data)

        output_path = os.path.join(
            get_data_path(),
            "crosswoz/readable_data",
            f"readable_{key}_data.json",
        )
        json.dump(
            sessions,
            open(output_path, "w", encoding="utf-8"),
            indent=2,
            ensure_ascii=False,
            sort_keys=False,
        )
        print(output_path)
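# A minimal sketch, on a fabricated one-turn session (not real CrossWOZ data),
# of the message -> readable-turn mapping preprocess() performs above.
from collections import OrderedDict

sess = {
    "messages": [
        {"role": "usr", "content": "你好",
         "dialog_act": [["Greet", "General", "none", "none"]]},
    ],
}
turns = [
    OrderedDict(role=m["role"], utterance=m["content"], dialog_act=m["dialog_act"])
    for m in sess["messages"]
]
assert turns[0]["utterance"] == "你好" and turns[0]["role"] == "usr"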
def main(): output_dir = os.path.join(get_data_path(), "crosswoz/policy_rule_single_domain_data") input_dir = os.path.join(get_data_path(), "crosswoz/raw") if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) datasets = ["train", "val", "tests"] for dataset in tqdm(datasets): output_file_name = dataset + ".json" input_file_name = output_file_name + ".zip" input_file_path = os.path.join(input_dir, input_file_name) output_file_path = os.path.join(output_dir, output_file_name) get_single_domain_examples(input_file_path, output_file_path)
def main():
    output_dir = os.path.join(get_data_path(), 'crosswoz/policy_rule_single_domain_data')
    input_dir = os.path.join(get_data_path(), 'crosswoz/raw')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    datasets = ['train', 'val', 'test']
    for dataset in tqdm(datasets):
        output_file_name = dataset + '.json'
        input_file_name = output_file_name + '.zip'
        input_file_path = os.path.join(input_dir, input_file_name)
        output_file_path = os.path.join(output_dir, output_file_name)
        get_single_domain_examples(input_file_path, output_file_path)
def merge_raw_date(data_type: str) -> None: """Merge belief state data into user turn Args: data_type: train, dev or tests """ data_path = get_data_path() output_dir = os.path.join(data_path, "crosswoz/dst_bert_data") if not os.path.exists(output_dir): os.makedirs(output_dir) dials_path = os.path.join(data_path, "crosswoz/dst_trade_data", f"{data_type}_dials.json") raw_filename = "val" if data_type == "dev" else data_type raw_path = os.path.join(data_path, "crosswoz/raw", f"{raw_filename}.json.zip") dials = json.load(open(dials_path, "r", encoding="utf8")) raw = read_zipped_json(raw_path, f"{raw_filename}.json") merge_data = {} for dial in tqdm(dials, desc=f"Merging {data_type}"): dialogue_idx = dial["dialogue_idx"] cur_raw = raw[dialogue_idx] merge_data[dialogue_idx] = cur_raw for turn_id, turn in enumerate(dial["dialogue"]): assert merge_data[dialogue_idx]["messages"][ 2 * turn_id]["role"] == "usr" merge_data[dialogue_idx]["messages"][ 2 * turn_id]["belief_state"] = turn["belief_state"] with open(os.path.join(output_dir, f"{data_type}4bert_dst.json"), "w", encoding="utf8") as f: json.dump(merge_data, f, ensure_ascii=False, indent=2)
def merge_raw_date(data_type: str) -> None: """Merge belief state data into user turn Args: data_type: train, dev or test """ data_path = get_data_path() output_dir = os.path.join(data_path, 'crosswoz/dst_bert_data') if not os.path.exists(output_dir): os.makedirs(output_dir) dials_path = os.path.join(data_path, 'crosswoz/dst_trade_data', f'{data_type}_dials.json') raw_filename = "val" if data_type == "dev" else data_type raw_path = os.path.join(data_path, 'crosswoz/raw', f'{raw_filename}.json.zip') dials = json.load(open(dials_path, 'r', encoding='utf8')) raw = read_zipped_json(raw_path, f'{raw_filename}.json') merge_data = {} for dial in tqdm(dials, desc=f'Merging {data_type}'): dialogue_idx = dial['dialogue_idx'] cur_raw = raw[dialogue_idx] merge_data[dialogue_idx] = cur_raw for turn_id, turn in enumerate(dial['dialogue']): assert merge_data[dialogue_idx]['messages'][2 * turn_id]['role'] == 'usr' merge_data[dialogue_idx]['messages'][2 * turn_id]['belief_state'] = turn['belief_state'] with open(os.path.join(output_dir, f'{data_type}4bert_dst.json'), 'w', encoding='utf8') as f: json.dump(merge_data, f, ensure_ascii=False, indent=2)
def __init__(self):
    # path
    root_path = get_root_path()
    config_file = os.path.join(get_config_path(),
                               IntentWithBertPredictor.default_model_config)

    # load config
    config = json.load(open(config_file))
    device = config['DEVICE']

    # load intent vocabulary and dataloader
    intent_vocab = json.load(open(os.path.join(
        get_data_path(), 'crosswoz/nlu_intent_data/intent_vocab.json'),
        encoding='utf-8'))
    dataloader = Dataloader(intent_vocab=intent_vocab,
                            pretrained_weights=config['model']['pretrained_weights'])

    # load best model
    best_model_path = os.path.join(DEFAULT_MODEL_PATH,
                                   IntentWithBertPredictor.default_model_name)
    if not os.path.exists(best_model_path):
        download_from_url(IntentWithBertPredictor.default_model_url,
                          best_model_path)
    model = IntentWithBert(config['model'], device, dataloader.intent_dim)
    try:
        model.load_state_dict(torch.load(best_model_path, map_location='cpu'))
    except Exception as e:
        print(e)

    # run inference on cpu
    model.to('cpu')
    model.eval()
    self.model = model
    self.dataloader = dataloader
    print(f'{best_model_path} loaded')
def load_config() -> dict: """Load config from common config and inference config from xbot/config/dst/bert . Returns: config dict """ root_path = get_root_path() common_config_path = os.path.join(get_config_path(), BertDST.common_config_name) infer_config_path = os.path.join(get_config_path(), BertDST.infer_config_name) common_config = json.load(open(common_config_path)) infer_config = json.load(open(infer_config_path)) infer_config.update(common_config) infer_config['device'] = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') infer_config['data_path'] = os.path.join(get_data_path(), 'crosswoz/dst_bert_data') infer_config['output_dir'] = os.path.join(root_path, infer_config['output_dir']) if not os.path.exists(infer_config['data_path']): os.makedirs(infer_config['data_path']) if not os.path.exists(infer_config['output_dir']): os.makedirs(infer_config['output_dir']) return infer_config
def main():
    train_config_name = 'policy/bert/train.json'
    common_config_name = 'policy/bert/common.json'
    data_urls = {
        'config.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/config.json',
        'pytorch_model.bin': 'http://qiw2jpwfc.hn-bkt.clouddn.com/pytorch_model.bin',
        'vocab.txt': 'http://qiw2jpwfc.hn-bkt.clouddn.com/vocab.txt',
        'act_ontology.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/act_ontology.json',
    }

    train_config = update_config(common_config_name, train_config_name,
                                 'crosswoz/policy_bert_data')
    train_config['raw_data_path'] = os.path.join(get_data_path(), 'crosswoz/raw')

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(train_config['data_path'], data_key)
        file_name = data_key.split('.')[0]
        train_config[file_name] = dst
        if not os.path.exists(dst):
            download_from_url(url, dst)

    pl.seed_everything(train_config['seed'])

    trainer = Trainer(train_config)
    trainer.train()
    # trainer.best_model_path = '/xhp/xbot/data/crosswoz/policy_bert_data/Epoch-6-f1-0.902'
    trainer.eval_test()
def main(): train_config_name = "policy/bert/train.json" common_config_name = "policy/bert/common.json" data_urls = { "config.json": "http://xbot.bslience.cn/bert-base-chinese/config.json", "pytorch_model.bin": "http://xbot.bslience.cn/bert-base-chinese/pytorch_model.bin", "vocab.txt": "http://xbot.bslience.cn/bert-base-chinese/vocab.txt", "act_ontology.json": "http://xbot.bslience.cn/act_ontology.json", } train_config = update_config(common_config_name, train_config_name, "crosswoz/policy_bert_data") train_config["raw_data_path"] = os.path.join(get_data_path(), "crosswoz/raw") # download data for data_key, url in data_urls.items(): dst = os.path.join(train_config["data_path"], data_key) file_name = data_key.split(".")[0] train_config[file_name] = dst if not os.path.exists(dst): download_from_url(url, dst) pl.seed_everything(train_config["seed"]) trainer = Trainer(train_config) trainer.train() trainer.eval_test()
def clean_ontology() -> None:
    """Clean ontology data."""
    data_path = get_data_path()
    ontology_path = os.path.join(data_path, "crosswoz/dst_bert_data/ontology.json")
    ontology = json.load(open(ontology_path, "r", encoding="utf8"))
    cleaned_ontologies = {}
    facility = []
    seps = ["、", ",", ",", ";", "或", ";", " "]
    for ds, values in tqdm(ontology.items()):
        # "酒店-酒店设施-xxx" keys: collect the facility name and skip
        if len(ds.split("-")) > 2:
            facility.append(ds.split("-")[-1])
            continue
        if ds.split("-")[-1] == "酒店设施":
            continue
        cleaned_values = set()
        for value in values:
            # split values that pack several items behind one separator
            multi_values = [value]
            for sep in seps:
                if sep in value:
                    multi_values = value.split(sep)
                    break
            for v in multi_values:
                v = "".join(v.split())  # drop internal whitespace
                cleaned_values.add(v)
        cleaned_ontologies[ds] = list(cleaned_values)
    # collect all facilities under a single domain-slot key
    cleaned_ontologies["酒店-酒店设施"] = facility

    cleaned_ontologies_path = os.path.join(
        data_path, "crosswoz/dst_bert_data/cleaned_ontology.json")
    with open(cleaned_ontologies_path, "w", encoding="utf8") as f:
        json.dump(cleaned_ontologies, f, ensure_ascii=False, indent=2)
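# A minimal sketch of the value-splitting rule above, on fabricated ontology
# values: the first separator found splits the value, then all whitespace
# inside each piece is removed.
seps = ['、', ',', ',', ';', '或', ';', ' ']
value = '免费 wifi、热水'
multi_values = [value]
for sep in seps:
    if sep in value:
        multi_values = value.split(sep)
        break
cleaned = {''.join(v.split()) for v in multi_values}
assert cleaned == {'免费wifi', '热水'}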
def load_act_ontology() -> Tuple[List[str], int]:
    """Load action ontology from cache.

    Returns:
        action ontology and the number of actions
    """
    act_ontology = load_json(
        os.path.join(get_data_path(), "crosswoz/policy_bert_data/act_ontology.json"))
    num_act = len(act_ontology)
    return act_ontology, num_act
def eval_metrics(gold_pred_sys_das):
    tp, fp, fn = 0, 0, 0
    joint_acc = total = 0
    bad_case = {}
    for dia_id, sess in gold_pred_sys_das.items():
        for turn_id, turn in sess.items():
            if not turn['gold_sys_act'] and not turn['pred_sys_act']:
                joint_acc += 1
            elif not turn['pred_sys_act']:
                fn += len(turn['gold_sys_act'])
            elif not turn['gold_sys_act']:
                fp += len(turn['pred_sys_act'])
            # When the intent is Recommend, or the slot starts with 周边
            # (nearby), the number of results given in the dataset follows no
            # rule, so any such turn is counted as correct; the predictions
            # essentially cover the results in the dataset.
            elif ((turn['gold_sys_act'][0][0] == turn['pred_sys_act'][0][0] == 'Recommend')
                  or turn['gold_sys_act'][0][2].startswith('周边')):
                joint_acc += 1
                tp += len(turn['gold_sys_act'])
            else:
                gold = set(turn['gold_sys_act'])
                pred = set(turn['pred_sys_act'])
                if gold != pred:
                    if dia_id not in bad_case:
                        bad_case[dia_id] = {}
                    bad_case[dia_id][str(turn_id)] = {
                        'gold_sys_act': turn['gold_sys_act'],
                        'pred_sys_act': turn['pred_sys_act']
                    }
                else:
                    joint_acc += 1
                tp += len(gold & pred)
                fn += len(gold - pred)
                fp += len(pred - gold)
            total += 1

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) != 0 else 0
    joint_acc /= total

    output_path = os.path.join(
        get_data_path(), 'crosswoz/policy_rule_single_domain_data/bad_case.json')
    dump_json(bad_case, output_path)
    return f1, precision, recall, joint_acc
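# A minimal sketch of the micro precision/recall/F1 bookkeeping above, on one
# fabricated turn: tp counts acts in both sets, fn gold-only, fp pred-only.
gold = {('Inform', '餐馆', '名称', '全聚德'), ('Inform', '餐馆', '评分', '4.6')}
pred = {('Inform', '餐馆', '名称', '全聚德'), ('Inform', '餐馆', '地址', '前门')}
tp, fn, fp = len(gold & pred), len(gold - pred), len(pred - gold)
precision, recall = tp / (tp + fp), tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
assert (tp, fn, fp) == (1, 1, 1) and f1 == 0.5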
def main(): model_config_name = "dst/bert/train.json" common_config_name = "dst/bert/common.json" data_urls = { "train4bert_dst.json": "http://xbot.bslience.cn/train4bert_dst.json", "dev4bert_dst.json": "http://xbot.bslience.cn/dev4bert_dst.json", "test4bert_dst.json": "http://xbot.bslience.cn/test4bert_dst.json", "cleaned_ontology.json": "http://xbot.bslience.cn/cleaned_ontology.json", "config.json": "http://xbot.bslience.cn/bert-base-chinese/config.json", "pytorch_model.bin": "http://xbot.bslience.cn/bert-base-chinese/pytorch_model.bin", "vocab.txt": "http://xbot.bslience.cn/bert-base-chinese/vocab.txt", } # load config root_path = get_root_path() common_config_path = os.path.join(get_config_path(), common_config_name) train_config_path = os.path.join(get_config_path(), model_config_name) common_config = json.load(open(common_config_path)) train_config = json.load(open(train_config_path)) train_config.update(common_config) train_config["n_gpus"] = torch.cuda.device_count() train_config["train_batch_size"] = ( max(1, train_config["n_gpus"]) * train_config["train_batch_size"] ) train_config["device"] = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) train_config["data_path"] = os.path.join(get_data_path(), "crosswoz/dst_bert_data") train_config["output_dir"] = os.path.join(root_path, train_config["output_dir"]) if not os.path.exists(train_config["data_path"]): os.makedirs(train_config["data_path"]) if not os.path.exists(train_config["output_dir"]): os.makedirs(train_config["output_dir"]) # download data for data_key, url in data_urls.items(): dst = os.path.join(train_config["data_path"], data_key) file_name = data_key.split(".")[0] train_config[file_name] = dst if not os.path.exists(dst): download_from_url(url, dst) # train trainer = Trainer(train_config) trainer.train() trainer.eval_test() get_recall(train_config["data_path"])
def main():
    rule_dst = RuleDST()
    rule_policy = RulePolicy()

    train_path = os.path.join(
        get_data_path(), "crosswoz/policy_rule_single_domain_data/train.json")
    # train_path = os.path.join(get_data_path(), 'crosswoz/policy_rule_single_domain_data/single_bad_case.json')
    train_examples = json.load(open(train_path, encoding="utf8"))

    sys_state_action_pairs = {}
    for id_, dialogue in train_examples.items():
        sys_state_action_pair = {}
        sess = dialogue["messages"]
        rule_dst.init_session()
        for i, turn in enumerate(sess):
            if turn["role"] == "usr":
                rule_dst.update(usr_da=turn["dialog_act"])
                rule_dst.state["user_action"].clear()
                rule_dst.state["user_action"].extend(turn["dialog_act"])
                if i + 2 == len(sess):
                    rule_dst.state["terminated"] = True
            else:
                # fill empty belief-state slots from the annotated sys_state
                for domain, svs in turn["sys_state"].items():
                    for slot, value in svs.items():
                        if (slot != "selectedResults"
                                and not rule_dst.state["belief_state"][domain][slot]):
                            rule_dst.state["belief_state"][domain][slot] = value
                pred_sys_act = rule_policy.predict(rule_dst.state)
                sys_state_action_pair[str(i)] = {
                    "gold_sys_act": [tuple(act) for act in turn["dialog_act"]],
                    "pred_sys_act": [tuple(act) for act in pred_sys_act],
                }
                rule_dst.state["system_action"].clear()
                rule_dst.state["system_action"].extend(turn["dialog_act"])
        sys_state_action_pairs[id_] = sys_state_action_pair

    f1, precision, recall, joint_acc = eval_metrics(sys_state_action_pairs)
    print(
        f"f1: {f1:.3f}, precision: {precision:.3f}, recall: {recall:.3f}, joint_acc: {joint_acc:.3f}"
    )
def load_config() -> dict: """Load config for inference. Returns: config dict """ common_config_path = os.path.join(get_config_path(), BertPolicy.common_config_name) infer_config_path = os.path.join(get_config_path(), BertPolicy.inference_config_name) common_config = load_json(common_config_path) infer_config = load_json(infer_config_path) infer_config.update(common_config) infer_config["device"] = torch.device( "cuda" if torch.cuda.is_available() else "cpu") infer_config["data_path"] = os.path.join(get_data_path(), "crosswoz/policy_bert_data") if not os.path.exists(infer_config["data_path"]): os.makedirs(infer_config["data_path"]) return infer_config
def update_config(common_config_name, train_config_name, task_path):
    root_path = get_root_path()
    common_config_path = os.path.join(get_config_path(), common_config_name)
    train_config_path = os.path.join(get_config_path(), train_config_name)
    common_config = json.load(open(common_config_path))
    train_config = json.load(open(train_config_path))
    train_config.update(common_config)
    train_config["n_gpus"] = torch.cuda.device_count()
    train_config["train_batch_size"] = (
        max(1, train_config["n_gpus"]) * train_config["train_batch_size"])
    train_config["device"] = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    train_config["data_path"] = os.path.join(get_data_path(), task_path)
    train_config["output_dir"] = os.path.join(root_path, train_config["output_dir"])
    if not os.path.exists(train_config["data_path"]):
        os.makedirs(train_config["data_path"])
    if not os.path.exists(train_config["output_dir"]):
        os.makedirs(train_config["output_dir"])
    return train_config
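# A minimal usage sketch for update_config(), mirroring the call made in the
# BERT policy training entry point above; the config names and task path are
# the ones that script passes, not new assumptions.
train_config = update_config(
    common_config_name="policy/bert/common.json",
    train_config_name="policy/bert/train.json",
    task_path="crosswoz/policy_bert_data",
)
print(train_config["device"], train_config["train_batch_size"])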
def __init__(self):
    super(MLEPolicy, self).__init__()
    # load config
    common_config_path = os.path.join(get_config_path(), MLEPolicy.common_config_name)
    common_config = json.load(open(common_config_path))
    model_config_path = os.path.join(get_config_path(), MLEPolicy.model_config_name)
    model_config = json.load(open(model_config_path))
    model_config.update(common_config)
    self.model_config = model_config
    self.model_config["data_path"] = os.path.join(
        get_data_path(), "crosswoz/policy_mle_data")
    self.model_config["n_gpus"] = (0 if self.model_config["device"] == "cpu"
                                   else torch.cuda.device_count())
    self.model_config["device"] = torch.device(self.model_config["device"])

    # download data
    for model_key, url in MLEPolicy.model_urls.items():
        dst = os.path.join(self.model_config["data_path"], model_key)
        file_name = (model_key.split(".")[0]
                     if not model_key.endswith("pth") else "trained_model_path")
        self.model_config[file_name] = dst
        if not os.path.exists(dst) or not self.model_config["use_cache"]:
            download_from_url(url, dst)

    self.vector = CrossWozVector(
        sys_da_voc_json=self.model_config["sys_da_voc"],
        usr_da_voc_json=self.model_config["usr_da_voc"],
    )

    policy = MultiDiscretePolicy(self.vector.state_dim,
                                 model_config["hidden_size"],
                                 self.vector.sys_da_dim)
    policy.load_state_dict(torch.load(self.model_config["trained_model_path"]))
    self.policy = policy.to(self.model_config["device"]).eval()
    print(f'>>> {self.model_config["trained_model_path"]} loaded ...')
def __init__(self):
    # path
    root_path = get_root_path()
    config_file = os.path.join(
        get_config_path(), IntentWithBertPredictor.default_model_config)

    # load config
    config = json.load(open(config_file))
    self.device = config["DEVICE"]

    # load intent vocabulary and dataloader
    intent_vocab = json.load(
        open(
            os.path.join(get_data_path(), "crosswoz/nlu_intent_data/intent_vocab.json"),
            encoding="utf-8",
        ))
    dataloader = Dataloader(
        intent_vocab=intent_vocab,
        pretrained_weights=config["model"]["pretrained_weights"],
    )

    # load best model
    best_model_path = os.path.join(
        root_path,
        DEFAULT_MODEL_PATH,
        IntentWithBertPredictor.default_model_name,
    )
    # best_model_path = os.path.join(DEFAULT_MODEL_PATH, IntentWithBertPredictor.default_model_name)
    if not os.path.exists(best_model_path):
        download_from_url(IntentWithBertPredictor.default_model_url, best_model_path)
    model = IntentWithBert(config["model"], self.device, dataloader.intent_dim)
    model.load_state_dict(torch.load(best_model_path, map_location=self.device))

    model.to(self.device)
    model.eval()
    self.model = model
    self.dataloader = dataloader
    print(f"{best_model_path} loaded")
def load_config() -> dict: """Load config from common config and inference config from src/xbot/config/dst/bert . Returns: config dict """ root_path = get_root_path() common_config_path = os.path.join(get_config_path(), BertDST.common_config_name) infer_config_path = os.path.join(get_config_path(), BertDST.infer_config_name) common_config = json.load(open(common_config_path)) infer_config = json.load(open(infer_config_path)) infer_config.update(common_config) infer_config["device"] = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) infer_config["data_path"] = os.path.join( get_data_path(), "crosswoz/dst_bert_data" ) infer_config["output_dir"] = os.path.join(root_path, infer_config["output_dir"]) if not os.path.exists(infer_config["data_path"]): os.makedirs(infer_config["data_path"]) if not os.path.exists(infer_config["output_dir"]): os.makedirs(infer_config["output_dir"]) return infer_config
if __name__ == '__main__':
    data_urls = {
        'intent_train_data.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/intent_train_data.json',
        'intent_val_data.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/intent_val_data.json',
        'intent_test_data.json': 'http://qiw2jpwfc.hn-bkt.clouddn.com/intent_test_data.json'
    }

    # load config
    root_path = get_root_path()
    config_path = os.path.join(get_config_path(), 'crosswoz_all_context_nlu_intent.json')
    config = json.load(open(config_path))
    data_path = os.path.join(get_data_path(), 'crosswoz/nlu_intent_data/')
    output_dir = os.path.join(root_path, config['output_dir'])
    log_dir = os.path.join(root_path, config['log_dir'])
    device = config['DEVICE']

    # download data
    for data_key, url in data_urls.items():
        dst = os.path.join(data_path, data_key)
        if not os.path.exists(dst):
            download_from_url(url, dst)

    # seed
    set_seed(config['seed'])
                if cur_token == 'EOS':
                    break
                resp_tokens.append(cur_token)
            slot_value = ' '.join(resp_tokens)
            if slot_value != 'none':
                predict_belief.append((self.all_slots[i], slot_value))
        self.update_belief_state(predict_belief)


if __name__ == '__main__':
    import random

    dst_model = TradeDST()
    data_path = os.path.join(get_data_path(), 'crosswoz/dst_trade_data')
    dials_path = os.path.join(data_path, 'dev_dials.json')
    # download dials file
    if not os.path.exists(dials_path):
        download_from_url('http://qiw2jpwfc.hn-bkt.clouddn.com/dev_dials.json',
                          dials_path)
    with open(os.path.join(data_path, 'dev_dials.json'), 'r', encoding='utf8') as f:
        dials = json.load(f)

    example = random.choice(dials)
    break_turn = 0
    for ti, turn in enumerate(example['dialogue']):
        dst_model.state['history'].append(('sys', turn['system_transcript']))
        dst_model.state['history'].append(('usr', turn['transcript']))
        if random.random() < 0.5:
            break_turn = ti + 1
            break
    if break_turn == len(example['dialogue']):
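# A minimal sketch of the decoding convention in the fragment above, with
# fabricated decoder output: tokens are collected until 'EOS', joined with
# spaces, and the sentinel value 'none' marks a slot as not predicted.
decoded = ['北京', '烤鸭', 'EOS', 'PAD']
resp_tokens = []
for tok in decoded:
    if tok == 'EOS':
        break
    resp_tokens.append(tok)
slot_value = ' '.join(resp_tokens)
assert slot_value == '北京 烤鸭' and slot_value != 'none'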
value = metro[1]["地铁"] if value is not None: break return value if __name__ == "__main__": from xbot.dm.dst.rule_dst.rule import RuleDST from xbot.util.path import get_data_path from xbot.util.file_util import read_zipped_json from script.policy.rule.rule_test import eval_metrics from tqdm import tqdm rule_dst = RuleDST() bert_policy = BertPolicy() train_path = os.path.join(get_data_path(), "crosswoz/raw/train.json.zip") train_examples = read_zipped_json(train_path, "train.json") sys_state_action_pairs = {} for id_, dialogue in tqdm(train_examples.items()): sys_state_action_pair = {} sess = dialogue["messages"] rule_dst.init_session() for i, turn in enumerate(sess): if turn["role"] == "usr": rule_dst.update(usr_da=turn["dialog_act"]) rule_dst.state["user_action"].clear() rule_dst.state["user_action"].extend(turn["dialog_act"]) rule_dst.state["history"].append(["usr", turn["content"]]) if i + 2 == len(sess): rule_dst.state["terminated"] = True
def __init__(self):
    super(TradeDST, self).__init__()
    # load config
    common_config_path = os.path.join(get_config_path(), TradeDST.common_config_name)
    common_config = json.load(open(common_config_path))
    model_config_path = os.path.join(get_config_path(), TradeDST.model_config_name)
    model_config = json.load(open(model_config_path))
    model_config.update(common_config)
    self.model_config = model_config
    self.model_config["data_path"] = os.path.join(
        get_data_path(), "crosswoz/dst_trade_data")
    self.model_config["n_gpus"] = (0 if self.model_config["device"] == "cpu"
                                   else torch.cuda.device_count())
    self.model_config["device"] = torch.device(self.model_config["device"])
    if model_config["load_embedding"]:
        model_config["hidden_size"] = 300

    # download data
    for model_key, url in TradeDST.model_urls.items():
        dst = os.path.join(self.model_config["data_path"], model_key)
        if model_key.endswith("pth"):
            file_name = "trained_model_path"
        elif model_key.endswith("pkl"):
            file_name = model_key.rsplit("-", maxsplit=1)[0]
        else:
            file_name = model_key.split(".")[0]  # ontology
        self.model_config[file_name] = dst
        if not os.path.exists(dst) or not self.model_config["use_cache"]:
            download_from_url(url, dst)

    # load data & model
    ontology = json.load(open(self.model_config["ontology"], "r", encoding="utf8"))
    self.all_slots = get_slot_information(ontology)
    self.gate2id = {"ptr": 0, "none": 1}
    self.id2gate = {id_: gate for gate, id_ in self.gate2id.items()}
    self.lang = pickle.load(open(self.model_config["lang"], "rb"))
    self.mem_lang = pickle.load(open(self.model_config["mem-lang"], "rb"))

    model = Trade(
        lang=self.lang,
        vocab_size=len(self.lang.index2word),
        hidden_size=self.model_config["hidden_size"],
        dropout=self.model_config["dropout"],
        num_encoder_layers=self.model_config["num_encoder_layers"],
        num_decoder_layers=self.model_config["num_decoder_layers"],
        pad_id=self.model_config["pad_id"],
        slots=self.all_slots,
        num_gates=len(self.gate2id),
        unk_mask=self.model_config["unk_mask"],
    )
    model.load_state_dict(torch.load(self.model_config["trained_model_path"]))
    self.model = model.to(self.model_config["device"]).eval()
    print(f'>>> {self.model_config["trained_model_path"]} loaded ...')

    self.state = default_state()
    print(">>> State initialized ...")
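# A minimal sketch of the cache-key derivation above; the artifact names are
# hypothetical stand-ins, only the suffix rules (.pth -> trained_model_path,
# .pkl -> rsplit on the last '-', else the file stem) come from the code.
def cache_key(model_key: str) -> str:
    if model_key.endswith('pth'):
        return 'trained_model_path'
    if model_key.endswith('pkl'):
        return model_key.rsplit('-', maxsplit=1)[0]
    return model_key.split('.')[0]

assert cache_key('trade-best.pth') == 'trained_model_path'   # hypothetical name
assert cache_key('lang-train.pkl') == 'lang'                 # hypothetical name
assert cache_key('mem-lang-train.pkl') == 'mem-lang'         # hypothetical name
assert cache_key('ontology.json') == 'ontology'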
        output_path: save path of action ontology file
    """
    raw_data = read_zipped_json(raw_data_path, "train.json")
    act_ontology = set()
    for dial_id, dial in tqdm(raw_data.items(), desc="Generate action ontology ..."):
        for turn_id, turn in enumerate(dial["messages"]):
            if turn["role"] == "sys" and turn["dialog_act"]:
                for da in turn["dialog_act"]:
                    # dialog_act quadruple is [intent, domain, slot, value]
                    act = "-".join([da[1], da[0], da[2]])
                    act_ontology.add(act)
    dump_json(list(act_ontology), output_path)


data_path = get_data_path()
raw_data_path = os.path.join(data_path, "crosswoz/raw/train.json.zip")
output_path = os.path.join(data_path, "crosswoz/policy_bert_data/act_ontology.json")
get_act_ontology(raw_data_path, output_path)
ACT_ONTOLOGY, NUM_ACT = load_act_ontology()


def preprocess(raw_data_path: str, output_path: str,
               data_type: str) -> List[Dict[str, list]]:
    """Preprocess raw data to generate model inputs.

    Args:
        raw_data_path: raw (train, dev, test) data path
        output_path: save path of processed data file
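# A minimal sketch of the act-string construction above, with one fabricated
# CrossWOZ dialog act in [intent, domain, slot, value] order: the ontology
# entry is "domain-intent-slot", and the value is dropped.
da = ['Inform', '餐馆', '名称', '全聚德烤鸭店']
act = '-'.join([da[1], da[0], da[2]])
assert act == '餐馆-Inform-名称'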