def generate(total_num=1000, seed=42, output_file='goal.json'):
    """Sample ``total_num`` user goals and dump them to ``output_file`` as JSON.

    The police domain is removed from every sampled goal. Goals for which
    ``GoalGenerator.build_message`` raises are skipped and resampled.

    Args:
        total_num: number of goals to collect and write.
        seed: seed applied to both ``random`` and ``numpy.random``.
        output_file: path of the JSON file that receives the goal list.
    """
    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # Drop the police domain from both the ordering and the goal dict.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Presumably some sampled goals cannot be described; resample.
            # (Not a bare except, so Ctrl-C still stops the loop.)
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals),
        })
    # Guard: np.mean([]) warns and yields nan when no goals were requested.
    if avg_domains:
        print('avg domains:', np.mean(avg_domains))  # avg domains: 1.846
    # Context manager so the output file is flushed and closed deterministically.
    with open(output_file, 'w') as f:
        json.dump(goals, f, indent=4)
def test_generate_overlap(total_num=1000, seed=42, output_file='goal.json'):
    """Compare goal slot-combination overlap against MultiWOZ train goals.

    Step 1 prints how many MultiWOZ test-set goals share their serialized
    slot combination (``extract_slot_combination_from_goal``) with some
    train-set goal. Step 2 samples ``total_num`` goals from ``GoalGenerator``
    (police domain removed) and prints the same overlap statistic for them.

    Args:
        total_num: number of goals to sample in step 2.
        seed: seed applied to both ``random`` and ``numpy.random``.
        output_file: NOTE(review): accepted for parity with ``generate`` but
            never used -- nothing is written to disk.

    NOTE(review): the ``test_`` prefix means pytest would collect and run
    this function if the module is under test discovery. The data paths are
    relative, presumably resolved against the current working directory.
    """
    def _count_overlap(candidates, reference):
        # Plain list membership: the type returned by
        # extract_slot_combination_from_goal is not visible here, so it is
        # not assumed to be hashable (which a set would require).
        return sum(1 for item in candidates if item in reference)

    train_data = read_zipped_json('../../../data/multiwoz/train.json.zip', 'train.json')
    train_serialized_goals = [
        extract_slot_combination_from_goal(train_data[d]['goal']) for d in train_data
    ]
    test_data = read_zipped_json('../../../data/multiwoz/test.json.zip', 'test.json')
    test_serialized_goals = [
        extract_slot_combination_from_goal(test_data[d]['goal']) for d in test_data
    ]
    overlap = _count_overlap(test_serialized_goals, train_serialized_goals)
    print(len(train_serialized_goals), len(test_serialized_goals), overlap)
    # 8434 1000 430

    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    serialized_goals = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # Drop the police domain from both the ordering and the goal dict.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Skip goals whose description cannot be built; not a bare
            # except, so Ctrl-C still interrupts the loop.
            continue
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals),
        })
        serialized_goals.append(extract_slot_combination_from_goal(goal))
        if len(serialized_goals) == 1:
            print(serialized_goals)
    overlap = _count_overlap(serialized_goals, train_serialized_goals)
    print(len(train_serialized_goals), len(serialized_goals), overlap)
    # 8434 1000 199
def __init__(self, archive_file=DEFAULT_ARCHIVE_FILE,
             model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'):
    """Build the VHUS user simulator and load its weights.

    Reads ``config.json`` from this module's directory, sizes the VHUS
    network from the vocabulary sizes reported by ``UserDataManager``,
    moves it to ``DEVICE`` in eval mode, and finally calls
    ``self.load(archive_file, model_file, config['load'])``.

    Args:
        archive_file: forwarded to ``self.load``.
        model_file: forwarded to ``self.load``; defaults to a remote zip URL.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, 'config.json')
    with open(config_path, 'r') as config_fp:
        config = json.load(config_fp)

    data_manager = UserDataManager()
    goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
    network = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
    self.user = network.to(device=DEVICE)
    self.goal_gen = GoalGenerator()
    self.manager = data_manager
    # Inference only: switch the network out of training mode before loading.
    self.user.eval()
    self.load(archive_file, model_file, config['load'])
def __init__(self):
    """Initialise the agenda-based user policy.

    Sets the dialogue limits (``max_turn`` = 40, ``max_initiative`` = 4),
    creates a ``GoalGenerator``, clears the per-dialogue state (turn
    counter, goal, agenda), and then runs ``Policy.__init__``.
    """
    # Dialogue-level limits.
    self.max_turn, self.max_initiative = 40, 4
    self.goal_generator = GoalGenerator()
    # Per-dialogue state; presumably filled in when a session starts -- TODO confirm.
    self.__turn = 0
    self.goal = self.agenda = None
    Policy.__init__(self)
def __init__(self, goal_generator: GoalGenerator):
    """Sample a user goal and convert it into a trackable form.

    Args:
        goal_generator (GoalGenerator): source of the raw goal. The dict
            returned by its ``get_user_goal()`` is stored as
            ``self.domain_goals`` and modified in place.

    After construction:
        * ``self.domains`` lists the domains in their original order, and
          ``'domain_ordering'`` is removed from ``self.domain_goals``.
        * every ``'reqt'`` slot collection becomes a dict mapping each slot
          to ``DEF_VAL_UNK``.
        * every domain with a ``'book'`` entry gains ``'booked'`` set to
          ``DEF_VAL_UNK``.
    """
    self.domain_goals = goal_generator.get_user_goal()
    self.domains = list(self.domain_goals['domain_ordering'])
    del self.domain_goals['domain_ordering']

    for domain_name in self.domains:
        sub_goal = self.domain_goals[domain_name]
        if 'reqt' in sub_goal:
            # Requested slots start out unknown.
            sub_goal['reqt'] = dict.fromkeys(sub_goal['reqt'], DEF_VAL_UNK)
        if 'book' in sub_goal:
            sub_goal['booked'] = DEF_VAL_UNK
# evaluator = MultiWozEvaluator() # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator) # user_policy = UserPolicyAgendaMultiWoz() # # sys_policy = RuleBasedMultiwozBot() # # user_nlg = TemplateNLG(is_user=True, mode='manual') # sys_nlg = TemplateNLG(is_user=False, mode='manual') # # dst = RuleDST() # # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', # model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip') # goal_generator = GoalGenerator() # while True: # goal = goal_generator.get_user_goal() # if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']: # break # # pprint(goal) user_goal = { 'domain_ordering': ('restaurant', 'hotel', 'taxi'), 'hotel': { 'book': { 'day': 'sunday', 'people': '6', 'stay': '4' }, 'info': { 'internet': 'no',
def __init__(self, opt, agent, num_extra_trial=2, max_turn=50, max_resp_time=180,
             model_agent_opt=None, world_tag='', agent_timeout_shutdown=180):
    """Set up one evaluation dialogue between a human agent and a bot.

    Initialises bookkeeping for the chat and the post-chat evaluation,
    picks a dialogue system at random from ``self.model_agents``, and
    samples a user goal (also rendered as an HTML list in
    ``self.goal_text``).

    Args:
        opt: options mapping. ``opt['is_sandbox']`` selects the task type,
            and the mapping is forwarded to the parent constructor.
        agent: forwarded to the parent constructor and stored as ``self.agent``.
        num_extra_trial: stored as ``self.num_extra_trial``.
        max_turn: stored as ``self.max_turn``.
        max_resp_time: response timeout in seconds.
        model_agent_opt: NOTE(review): currently unused.
        world_tag: stored as ``self.world_tag``.
        agent_timeout_shutdown: stored as ``self.agent_timeout_shutdown``.

    Raises:
        RuntimeError: if no user goal could be sampled within 100 attempts.
    """
    self.opt = opt
    self.agent = agent
    self.turn_idx = 1
    self.hit_id = None
    self.max_turn = max_turn
    self.num_extra_trial = num_extra_trial
    self.dialog = []
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
    self.eval_done = False
    self.chat_done = False
    self.success = False
    self.success_attempts = []
    self.fail_attempts = []
    self.fail_reason = None
    self.understanding_score = -1
    self.understanding_reason = None
    self.appropriateness_score = -1
    self.appropriateness_reason = None
    self.world_tag = world_tag
    self.ratings = ['1', '2', '3', '4', '5']
    super().__init__(opt, agent)

    # Set up the model agent. Other bots (cambridge, sequicity, RuleBot)
    # are currently disabled; only DQNBot is loaded.
    self.model_agents = {
        "DQNBot": DQNBot()
    }
    self.model_name = random.choice(list(self.model_agents.keys()))
    self.model_agent = self.model_agents[self.model_name]
    print("Bot is loaded")

    # Timeout protocols.
    self.max_resp_time = max_resp_time  # in secs
    self.agent_timeout_shutdown = agent_timeout_shutdown

    # Sample a user goal, retrying because get_user_goal may raise.
    self.goal = None
    goal_generator = GoalGenerator(boldify=True)
    num_goal_trials = 0
    while num_goal_trials < 100 and self.goal is None:
        try:
            self.goal = goal_generator.get_user_goal()
        except Exception as e:
            print(e)
        num_goal_trials += 1
    if self.goal is None:
        # Fail loudly here instead of crashing obscurely inside build_message(None).
        raise RuntimeError(
            'failed to sample a user goal after %d attempts' % num_goal_trials)

    self.goal_message, _ = goal_generator.build_message(self.goal)
    items = ''.join('<li>' + m + '</li>' for m in self.goal_message)
    self.goal_text = '<ul>' + items + '</ul>'
    print(self.goal_text)
    print(self.goal)
    # Separate copies so later mutation of the state does not touch the goal.
    self.final_goal = deepcopy(self.goal)
    self.state = deepcopy(self.goal)