Ejemplo n.º 1
0
def generate(total_num=1000, seed=42, output_file='goal.json'):
    """Sample ``total_num`` user goals and write them to ``output_file`` as JSON.

    Goals come from ``GoalGenerator.get_user_goal()``. Any ``police`` domain is
    removed from a goal before its description is built. Goals for which
    ``build_message`` fails are discarded and another goal is sampled.

    Args:
        total_num: Number of goal entries to write.
        seed: Seed applied to both ``random`` and ``numpy.random``.
        output_file: Path of the JSON file that receives the list of entries.
    """
    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # The police domain is excluded from the generated goals.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Could not describe this goal: drop it and sample another.
            # (Not a bare ``except`` so Ctrl-C can still stop the loop.)
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            # Evaluated before the append, so IDs are 0-based positions.
            "ID": len(goals)
        })
    if avg_domains:
        print('avg domains:', np.mean(avg_domains))  # avg domains: 1.846
    # Context manager guarantees the output file is flushed and closed.
    with open(output_file, 'w') as f:
        json.dump(goals, f, indent=4)
Ejemplo n.º 2
0
def test_generate_overlap(total_num=1000, seed=42, output_file='goal.json'):
    """Print how many goal slot-combinations also appear in the MultiWOZ train set.

    Two lines of the form ``<n_train> <n_compared> <n_overlap>`` are printed:
    first for the MultiWOZ test split, then for ``total_num`` goals sampled
    from ``GoalGenerator`` (with the ``police`` domain removed, as in
    ``generate``).

    Args:
        total_num: Number of goals to sample.
        seed: Seed applied to both ``random`` and ``numpy.random``.
        output_file: Not used by this function.
    """
    def _serialize_split(zip_path, member):
        # Slot-combination signature of every dialog goal in one data split.
        data = read_zipped_json(zip_path, member)
        return [extract_slot_combination_from_goal(data[d]['goal'])
                for d in data]

    def _count_overlap(candidates, reference):
        # Number of candidates (with multiplicity) present in ``reference``.
        return sum(1 for c in candidates if c in reference)

    train_serialized_goals = _serialize_split(
        '../../../data/multiwoz/train.json.zip', 'train.json')
    test_serialized_goals = _serialize_split(
        '../../../data/multiwoz/test.json.zip', 'test.json')
    print(len(train_serialized_goals), len(test_serialized_goals),
          _count_overlap(test_serialized_goals,
                         train_serialized_goals))  # 8434 1000 430

    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    serialized_goals = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # The police domain is excluded from the generated goals.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Could not describe this goal: drop it and sample another.
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals)
        })
        serialized_goals.append(extract_slot_combination_from_goal(goal))
        # Show the first serialized sample for inspection.
        if len(serialized_goals) == 1:
            print(serialized_goals)
    print(len(train_serialized_goals), len(serialized_goals),
          _count_overlap(serialized_goals,
                         train_serialized_goals))  # 8434 1000 199
Ejemplo n.º 3
0
    def __init__(self,
                 opt,
                 agent,
                 num_extra_trial=2,
                 max_turn=50,
                 max_resp_time=180,
                 model_agent_opt=None,
                 world_tag='',
                 agent_timeout_shutdown=180):
        """Initialise dialog/evaluation state, pick a bot, and sample a user goal.

        Args:
            opt: Option mapping; ``opt['is_sandbox']`` selects ``task_type``
                ('sandbox' or 'live'). Also forwarded to the parent constructor.
            agent: Forwarded to the parent constructor and stored as ``self.agent``.
            num_extra_trial: Stored as ``self.num_extra_trial``.
            max_turn: Stored as ``self.max_turn``.
            max_resp_time: Response time limit in seconds (``self.max_resp_time``).
            model_agent_opt: Not used in the code shown here.
            world_tag: Stored as ``self.world_tag``.
            agent_timeout_shutdown: Stored as ``self.agent_timeout_shutdown``.

        Raises:
            RuntimeError: If no user goal could be sampled within the retry cap.
        """
        self.opt = opt
        self.agent = agent
        self.turn_idx = 1
        self.hit_id = None
        self.max_turn = max_turn
        self.num_extra_trial = num_extra_trial
        self.dialog = []
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.eval_done = False
        self.chat_done = False
        self.success = False
        self.success_attempts = []
        self.fail_attempts = []
        self.fail_reason = None
        self.understanding_score = -1
        self.understanding_reason = None
        self.appropriateness_score = -1
        self.appropriateness_reason = None
        self.world_tag = world_tag
        self.ratings = ['1', '2', '3', '4', '5']
        super().__init__(opt, agent)

        # set up model agent: one candidate bot is chosen at random.
        self.model_agents = {
            # "cambridge": CambridgeBot(),
            # "sequicity": SequicityBot(),
            # "RuleBot": RuleBot(),
            "DQNBot": DQNBot()
        }
        self.model_name = random.choice(list(self.model_agents))
        self.model_agent = self.model_agents[self.model_name]
        print("Bot is loaded")

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        self.agent_timeout_shutdown = agent_timeout_shutdown

        # Sample the user goal. get_user_goal() may raise, so retry; every
        # attempt counts toward the cap, so a generator that returns None
        # without raising cannot loop forever.
        max_goal_trials = 100
        self.goal = None
        goal_generator = GoalGenerator(boldify=True)
        for _ in range(max_goal_trials):
            try:
                self.goal = goal_generator.get_user_goal()
            except Exception as e:
                print(e)
                continue
            if self.goal is not None:
                break
        if self.goal is None:
            raise RuntimeError(
                'Failed to sample a user goal after %d attempts'
                % max_goal_trials)

        self.goal_message, _ = goal_generator.build_message(self.goal)
        # Render the goal sentences as an HTML bullet list.
        self.goal_text = '<ul>' + ''.join(
            '<li>' + m + '</li>' for m in self.goal_message) + '</ul>'
        print(self.goal_text)

        print(self.goal)
        # Deep copies: mutating final_goal or state leaves self.goal untouched.
        self.final_goal = deepcopy(self.goal)
        self.state = deepcopy(self.goal)