def create_dataset_usr(self, part, file_dir, data_dir, cfg, db):
    datas = self.data[part]
    goals = self.goal[part]
    # user state + user action + reward + next user state + terminal flag
    s, a, r, next_s, t = [], [], [], [], []
    evaluator = MultiWozEvaluator(data_dir)
    current_goal = None
    for idx, turn_data in enumerate(datas):
        # system turn: only feed the system action to the evaluator
        if turn_data['others']['turn'] % 2 == 1:
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            continue

        if turn_data['others']['turn'] == 0:
            current_goal = goals[turn_data['others']['session_id']]
            evaluator.add_goal(current_goal)
        else:
            next_s.append(s[-1])
            # when the current goal cannot be satisfied, switch to the backup ('final') goal
            if turn_data['others']['change'] and evaluator.cur_domain:
                if 'final' in current_goal[evaluator.cur_domain]:
                    for key in current_goal[evaluator.cur_domain]['final']:
                        current_goal[evaluator.cur_domain][key] = \
                            current_goal[evaluator.cur_domain]['final'][key]
                    del current_goal[evaluator.cur_domain]['final']

        turn_data['user_goal'] = deepcopy(current_goal)
        s.append(torch.Tensor(state_vectorize_user(turn_data, cfg, evaluator.cur_domain)))
        a.append(torch.Tensor(action_vectorize_user(
            turn_data['trg_user_action'], turn_data['others']['terminal'], cfg)))
        evaluator.add_usr_da(turn_data['trg_user_action'])

        if turn_data['others']['terminal']:
            # build the terminal state from the following system turn
            next_turn_data = deepcopy(turn_data)
            next_turn_data['others']['turn'] = -1
            next_turn_data['user_action'] = turn_data['trg_user_action']
            next_turn_data['sys_action'] = datas[idx + 1]['trg_sys_action']
            next_turn_data['trg_user_action'] = {}
            next_turn_data['goal_state'] = datas[idx + 1]['final_goal_state']
            next_s.append(torch.Tensor(
                state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain)))
            # terminal reward: +20 if the user-side inform metric is perfect, else -5
            reward = 20 if evaluator.inform_F1(ansbysys=False)[1] == 1. else -5
            r.append(reward)
            t.append(1)
        else:
            reward = 0
            if evaluator.cur_domain:
                for da in turn_data['trg_user_action']:
                    d, i, k = da.split('-')
                    if i == 'request':
                        for slot, value in turn_data['goal_state'][d].items():
                            if value != '?' and slot in turn_data['user_goal'][d] \
                                    and turn_data['user_goal'][d][slot] != '?':
                                # request before expressing the constraint
                                reward -= 1
            if not turn_data['trg_user_action']:
                # penalize an empty user action
                reward -= 5
            r.append(reward)
            t.append(0)
    torch.save((s, a, r, next_s, t), file_dir)
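# A minimal sketch (not the repo's own loader) of how the tuple saved above could be
# consumed as transition minibatches. The file name 'train_usr.pt' and the batch size
# are hypothetical placeholders; the actual loading path lives elsewhere in the repo.
import torch
from torch.utils.data import DataLoader, TensorDataset

def load_usr_transitions(file_dir='train_usr.pt', batchsz=32):
    # torch.save stored plain Python lists; stack them into batched tensors
    s, a, r, next_s, t = torch.load(file_dir)
    dataset = TensorDataset(torch.stack(s),
                            torch.stack(a),
                            torch.tensor(r, dtype=torch.float32),
                            torch.stack(next_s),
                            torch.tensor(t, dtype=torch.float32))
    return DataLoader(dataset, batch_size=batchsz, shuffle=True)

# usage (hypothetical): iterate (s, a, r, s', done) minibatches
# for s_b, a_b, r_b, next_s_b, t_b in load_usr_transitions():
#     ...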
def __init__(self, env_cls, args, manager, cfg, process_num, character, pre=False, infer=False):
    """
    Trainer dedicated to updating the pretrained policy model.
    :param env_cls: env class or function, not an instance, as we need to create several instances in this class
    :param args: training arguments / hyperparameters
    :param manager: data manager that builds the pretraining datasets
    :param cfg: configuration
    :param process_num: process number
    :param character: user or system
    :param pre: set to pretrain mode
    :param infer: set to test mode
    """
    self.process_num = process_num
    self.character = character

    # initialize envs for each process
    self.env_list = []
    for _ in range(process_num):
        self.env_list.append(env_cls())

    # construct policy network
    self.policy = MultiDiscretePolicy(cfg, character).to(device=DEVICE)

    if pre:
        self.print_per_batch = args.print_per_batch
        from dbquery import DBQuery
        db = DBQuery(args.data_dir, cfg)
        self.data_train = manager.create_dataset_policy('train', args.batchsz, cfg, db, character)
        self.data_valid = manager.create_dataset_policy('valid', args.batchsz, cfg, db, character)
        self.data_test = manager.create_dataset_policy('test', args.batchsz, cfg, db, character)
        # up-weight positive labels in the multi-label action classification loss
        if character == 'sys':
            pos_weight = args.policy_weight_sys * torch.ones([cfg.a_dim]).to(device=DEVICE)
        elif character == 'usr':
            pos_weight = args.policy_weight_usr * torch.ones([cfg.a_dim_usr]).to(device=DEVICE)
        else:
            raise Exception('Unknown character')
        self.multi_entropy_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:
        self.evaluator = MultiWozEvaluator(args.data_dir, cfg.d)

    self.save_dir = args.save_dir + '/' + character if pre else args.save_dir
    self.save_per_epoch = args.save_per_epoch
    self.optim_batchsz = args.batchsz
    self.policy.eval()
    self.gamma = args.gamma
    self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=args.lr_policy,
                                      weight_decay=args.weight_decay)
    self.writer = SummaryWriter()
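# A standalone illustration of the pretraining loss configured above: BCEWithLogitsLoss
# with pos_weight up-weights the (rare) active bits of the multi-hot action label.
# The action dimension (20) and the weight value (2.5) below are made-up examples.
import torch
import torch.nn as nn

a_dim = 20
pos_weight = 2.5 * torch.ones([a_dim])   # same form as args.policy_weight_* above
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(4, a_dim)           # raw policy outputs for a batch of 4 states
target = torch.zeros(4, a_dim)
target[:, :3] = 1.                       # a few active dialog acts per example
loss = loss_fn(logits, target)           # positive labels contribute 2.5x to the loss
print(loss.item())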
class StateTracker(object):

    def __init__(self, data_dir, config):
        self.time_step = 0
        self.cfg = config
        self.db = DBQuery(data_dir, config)
        self.topic = ''
        self.evaluator = MultiWozEvaluator(data_dir)
        self.lock_evalutor = False

    def set_rollout(self, rollout):
        if rollout:
            # save the current tracker state and freeze the evaluator
            self.save_time_step = self.time_step
            self.save_topic = self.topic
            self.lock_evalutor = True
        else:
            # restore the saved tracker state and unfreeze the evaluator
            self.time_step = self.save_time_step
            self.topic = self.save_topic
            self.lock_evalutor = False

    def get_entities(self, s, domain):
        origin = s['belief_state'][domain].items()
        constraint = []
        for k, v in origin:
            if v != '?' and k in self.cfg.mapping[domain]:
                constraint.append((self.cfg.mapping[domain][k], v))
        entities = self.db.query(domain, constraint)
        random.shuffle(entities)
        return entities

    def update_belief_sys(self, old_s, a):
        """
        update belief/goal state with sys action
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices
        self.time_step += 1
        s['others']['turn'] = self.time_step
        # update sys/user dialog act
        s['sys_action'] = dict()
        # update belief part
        das = [self.cfg.idx2da[idx.item()] for idx in a_index]
        das = [da.split('-') for da in das]
        das.sort(key=lambda x: x[0])  # sort by domain
        entities = [] if self.topic == '' else self.get_entities(s, self.topic)
        return_flag = False
        for domain, intent, slot, p in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
                entities = self.get_entities(s, domain)
            da = '-'.join((domain, intent, slot, p))
            if intent == 'request':
                s['sys_action'][da] = '?'
            elif intent in ['nooffer', 'nobook'] and self.topic != '':
                # nothing can be offered/booked: remember to roll back user constraints below
                return_flag = True
                if slot in s['belief_state'][self.topic] and s['belief_state'][self.topic][slot] != '?':
                    s['sys_action'][da] = s['belief_state'][self.topic][slot]
                else:
                    s['sys_action'][da] = 'none'
            elif slot == 'choice':
                s['sys_action'][da] = str(len(entities))
            elif slot == 'none':
                s['sys_action'][da] = 'none'
            else:
                num = int(p) - 1
                if self.topic and len(entities) > num and slot in self.cfg.mapping[self.topic]:
                    typ = self.cfg.mapping[self.topic][slot]
                    if typ in entities[num]:
                        s['sys_action'][da] = entities[num][typ]
                    else:
                        s['sys_action'][da] = 'none'
                else:
                    s['sys_action'][da] = 'none'

            if not self.topic:
                continue

            if intent in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book']:
                discard(s['belief_state'][self.topic], slot, '?')
                if slot in s['user_goal'][self.topic] and s['user_goal'][self.topic][slot] == '?':
                    s['goal_state'][self.topic][slot] = s['sys_action'][da]

            # booked
            if intent == 'inform' and slot == 'car':  # taxi
                if 'booked' not in s['belief_state']['taxi']:
                    s['belief_state']['taxi']['booked'] = 'taxi-booked'
            elif intent in ['offerbooked', 'book'] and slot == 'ref':  # train
                if self.topic in ['taxi', 'hospital', 'police']:
                    s['belief_state'][self.topic]['booked'] = f'{self.topic}-booked'
                    s['sys_action'][da] = f'{self.topic}-booked'
                elif entities:
                    book_domain = entities[0]['ref'].split('-')[0]
                    if 'booked' not in s['belief_state'][book_domain] and entities:
                        s['belief_state'][book_domain]['booked'] = entities[0]['ref']
                        s['sys_action'][da] = entities[0]['ref']

        if return_flag:
            # drop the user constraints that could not be satisfied and reload the goal
            for da in s['user_action']:
                d_usr, i_usr, s_usr = da.split('-')
                if i_usr == 'inform' and d_usr == self.topic:
                    discard(s['belief_state'][d_usr], s_usr)
            reload(s['goal_state'], s['user_goal'], self.topic)

        if not self.lock_evalutor:
            self.evaluator.add_sys_da(s['sys_action'])
        return s

    def update_belief_usr(self, old_s, a):
        """
        update belief/goal state with user action
        """
        s = deepcopy(old_s)
        a_index = torch.nonzero(a)  # get multiple da indices
        self.time_step += 1
        s['others']['turn'] = self.time_step
        s['others']['terminal'] = 1 if (self.cfg.a_dim_usr - 1) in a_index else 0
        # update sys/user dialog act
        s['user_action'] = dict()
        # update belief part
        das = [self.cfg.idx2da_u[idx.item()] for idx in a_index
               if idx.item() != self.cfg.a_dim_usr - 1]
        das = [da.split('-') for da in das]
        if s['invisible_domains']:
            for da in das:
                if da[0] == s['next_available_domain']:
                    s['next_available_domain'] = s['invisible_domains'][0]
                    s['invisible_domains'].remove(s['next_available_domain'])
                    break
        das.sort(key=lambda x: x[0])  # sort by domain
        for domain, intent, slot in das:
            if domain in self.cfg.belief_domains and domain != self.topic:
                self.topic = domain
            da = '-'.join((domain, intent, slot))
            if intent == 'request':
                s['user_action'][da] = '?'
                s['belief_state'][self.topic][slot] = '?'
            elif slot == 'none':
                s['user_action'][da] = 'none'
            else:
                if self.topic and slot in s['user_goal'][self.topic] \
                        and s['user_goal'][domain][slot] != '?':
                    s['user_action'][da] = s['user_goal'][domain][slot]
                else:
                    s['user_action'][da] = 'dont care'

            if not self.topic:
                continue

            if intent == 'inform':
                s['belief_state'][domain][slot] = s['user_action'][da]
                if slot in s['user_goal'][self.topic] and s['user_goal'][self.topic][slot] != '?':
                    discard(s['goal_state'][self.topic], slot)

        if not self.lock_evalutor:
            self.evaluator.add_usr_da(s['user_action'])
        return s

    def reset(self, random_seed=None):
        """
        Args:
            random_seed (int):
        Returns:
            init_state (dict):
        """
        pass

    def step(self, s, sys_a):
        """
        Args:
            s (dict):
            sys_a (vector):
        Returns:
            next_s (dict):
            terminal (bool):
        """
        pass
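# A toy decoding step mirroring the start of update_belief_sys: torch.nonzero turns a
# multi-hot action vector into indices, which are mapped back to 'domain-intent-slot-p'
# strings. The idx2da table here is a made-up stand-in for cfg.idx2da.
import torch

idx2da = {0: 'hotel-inform-area-1', 1: 'restaurant-request-food-1', 2: 'general-bye-none-1'}
a = torch.tensor([1., 0., 1.])                    # the policy chose acts 0 and 2

a_index = torch.nonzero(a)                        # -> tensor([[0], [2]])
das = [idx2da[idx.item()] for idx in a_index]     # -> ['hotel-inform-area-1', 'general-bye-none-1']
das = [da.split('-') for da in das]
das.sort(key=lambda x: x[0])                      # group acts of the same domain together
print(das)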
def create_dataset_global(self, part, file_dir, data_dir, cfg, db):
    """
    Create the global dataset, which records every state on both the user side and
    the system side together with the global reward.
    """
    datas = self.data[part]
    goals = self.goal[part]
    s_usr, s_sys, r_g, next_s_usr, next_s_sys, t = [], [], [], [], [], []
    evaluator = MultiWozEvaluator(data_dir, cfg.d)
    for idx, turn_data in enumerate(datas):
        if turn_data['others']['turn'] % 2 == 0:
            if turn_data['others']['turn'] == 0:
                current_goal = goals[turn_data['others']['session_id']]
                evaluator.add_goal(current_goal)
            else:
                next_s_usr.append(s_usr[-1])
                # when the user goal cannot be satisfied, switch to the backup ('final') goal
                if turn_data['others']['change'] and evaluator.cur_domain:
                    if 'final' in current_goal[evaluator.cur_domain]:
                        for key in current_goal[evaluator.cur_domain]['final']:
                            current_goal[evaluator.cur_domain][key] = \
                                current_goal[evaluator.cur_domain]['final'][key]
                        del current_goal[evaluator.cur_domain]['final']
            turn_data['user_goal'] = deepcopy(current_goal)
            s_usr.append(torch.Tensor(
                state_vectorize_user(turn_data, cfg, evaluator.cur_domain)))
            evaluator.add_usr_da(turn_data['trg_user_action'])
            if turn_data['others']['terminal']:
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = turn_data['trg_user_action']
                next_turn_data['sys_action'] = datas[idx + 1]['trg_sys_action']
                next_turn_data['trg_user_action'] = {}
                next_turn_data['goal_state'] = datas[idx + 1]['final_goal_state']
                next_s_usr.append(torch.Tensor(
                    state_vectorize_user(next_turn_data, cfg, evaluator.cur_domain)))
        else:
            if turn_data['others']['turn'] != 1:
                next_s_sys.append(s_sys[-1])
            s_sys.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
            evaluator.add_sys_da(turn_data['trg_sys_action'])
            if turn_data['others']['terminal']:
                next_turn_data = deepcopy(turn_data)
                next_turn_data['others']['turn'] = -1
                next_turn_data['user_action'] = {}
                next_turn_data['sys_action'] = turn_data['trg_sys_action']
                next_turn_data['trg_sys_action'] = {}
                next_turn_data['belief_state'] = turn_data['final_belief_state']
                next_s_sys.append(torch.Tensor(
                    state_vectorize(next_turn_data, cfg, db, True)))
                # in a multi-turn dialog the system is assumed to speak last, so task
                # success judged on the system side serves as the global terminal reward
                reward_g = 20 if evaluator.task_success() else -5
                r_g.append(reward_g)
                t.append(1)
            else:
                # reward domain success; otherwise apply a small per-turn penalty to
                # encourage shorter dialogs (TODO: what exactly counts as domain_success?)
                reward_g = 5 if evaluator.cur_domain and evaluator.domain_success(
                    evaluator.cur_domain) else -1
                r_g.append(reward_g)
                t.append(0)
    torch.save((s_usr, s_sys, r_g, next_s_usr, next_s_sys, t), file_dir)
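# The global reward rule used above, pulled out into a small helper for readability.
# This is a paraphrase of the code, not a function that exists in the repo; `evaluator`
# is assumed to expose task_success() and domain_success(domain) as MultiWozEvaluator does.
def global_reward(evaluator, terminal):
    if terminal:
        # whole-dialog signal: large bonus for task success, penalty otherwise
        return 20 if evaluator.task_success() else -5
    # per-turn signal: small bonus once the current domain is satisfied,
    # otherwise -1 per turn to keep dialogs short
    if evaluator.cur_domain and evaluator.domain_success(evaluator.cur_domain):
        return 5
    return -1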
def create_dataset_sys(self, part, file_dir, data_dir, cfg, db):
    """
    Create the training data for the system policy.
    """
    datas = self.data[part]
    goals = self.goal[part]
    # system state + system action + reward + next system state + terminal flag
    s, a, r, next_s, t = [], [], [], [], []
    # the evaluator keeps track of the whole session
    evaluator = MultiWozEvaluator(data_dir, cfg.d)
    for idx, turn_data in enumerate(datas):
        # user turn: no data is collected on the user side here
        if turn_data['others']['turn'] % 2 == 0:
            # load the user goal at the first turn
            if turn_data['others']['turn'] == 0:
                evaluator.add_goal(goals[turn_data['others']['session_id']])
            evaluator.add_usr_da(turn_data['trg_user_action'])
            continue

        # shifted by one turn: the entry appended here serves as the next state
        if turn_data['others']['turn'] != 1:
            next_s.append(s[-1])
        # vectorize the current turn into a state vector
        s.append(torch.Tensor(state_vectorize(turn_data, cfg, db, True)))
        # vectorize the target system action into an action vector
        a.append(torch.Tensor(action_vectorize(turn_data['trg_sys_action'], cfg)))
        evaluator.add_sys_da(turn_data['trg_sys_action'])

        if turn_data['others']['terminal']:
            # final turn
            next_turn_data = deepcopy(turn_data)
            next_turn_data['others']['turn'] = -1
            next_turn_data['user_action'] = {}
            next_turn_data['sys_action'] = turn_data['trg_sys_action']
            next_turn_data['trg_sys_action'] = {}
            next_turn_data['belief_state'] = turn_data['final_belief_state']
            # record the terminal next state
            next_s.append(torch.Tensor(state_vectorize(next_turn_data, cfg, db, True)))
            # terminal reward: did the system fulfil the booking requests and answer
            # every question raised by the real user actions?
            reward = 20 if evaluator.task_success(False) else -5
            r.append(reward)
            # terminal flag
            t.append(1)
        else:
            reward = 0
            if evaluator.cur_domain:
                for slot, value in turn_data['belief_state'][evaluator.cur_domain].items():
                    if value == '?':
                        for da in turn_data['trg_sys_action']:
                            d, i, k, p = da.split('-')
                            if i in ['inform', 'recommend', 'offerbook', 'offerbooked'] and k == slot:
                                break
                        else:  # a request in the belief state was left unanswered: -1
                            reward -= 1
            if not turn_data['trg_sys_action']:
                # the system produced no action this turn: -5
                reward -= 5
            r.append(reward)
            t.append(0)
    torch.save((s, a, r, next_s, t), file_dir)
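# The per-turn shaping reward for the system, restated as a standalone helper (a paraphrase
# for clarity, not a function from the repo). `belief_state` is assumed to be the current
# domain's slot -> value dict and `sys_action` the 'domain-intent-slot-p' -> value dict.
def turn_reward_sys(belief_state, sys_action):
    reward = 0
    # slots the system addressed this turn with an informative intent
    answered_slots = {da.split('-')[2] for da in sys_action
                      if da.split('-')[1] in ['inform', 'recommend', 'offerbook', 'offerbooked']}
    for slot, value in belief_state.items():
        if value == '?' and slot not in answered_slots:
            reward -= 1          # a user request left unanswered this turn
    if not sys_action:
        reward -= 5              # the system said nothing at all
    return reward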