示例#1
0
def generate(total_num=1000, seed=42, output_file='goal.json'):
    """Sample user goals and write them to a JSON file.

    Goals are drawn from ``GoalGenerator`` until ``total_num`` of them could be
    rendered by ``build_message``. The ``police`` domain is stripped from every
    goal before rendering.

    Args:
        total_num: number of goal records to write.
        seed: seed applied to both ``random`` and ``numpy.random``.
        output_file: path of the JSON file to (over)write.
    """
    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # Drop the police domain from the ordering and from the goal itself.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Goals that build_message cannot render are skipped and resampled.
            # (A bare ``except:`` here would also swallow KeyboardInterrupt.)
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals)  # evaluated before the append -> IDs are 0..n-1
        })
    if avg_domains:
        # A previous run reported "avg domains: 1.846".
        print('avg domains:', np.mean(avg_domains))
    # Use a context manager so the file is flushed and closed deterministically.
    with open(output_file, 'w') as f:
        json.dump(goals, f, indent=4)
示例#2
0
def test_generate_overlap(total_num=1000, seed=42, output_file='goal.json'):
    """Report how many test-set and generated goals share a slot combination with train.

    Two lines are printed, each ``<#train> <#compared> <#overlap>``:
    1. MultiWOZ test goals vs. MultiWOZ train goals;
    2. ``total_num`` goals sampled from ``GoalGenerator`` vs. train goals.

    Args:
        total_num: number of goals to sample from the generator.
        seed: seed applied to both ``random`` and ``numpy.random``.
        output_file: unused; kept for signature compatibility with ``generate``.
    """

    def _serialized_split(split):
        # Load one MultiWOZ split and serialize each dialogue's goal.
        data = read_zipped_json('../../../data/multiwoz/{}.json.zip'.format(split),
                                '{}.json'.format(split))
        return [extract_slot_combination_from_goal(data[d]['goal']) for d in data]

    def _count_overlap(candidates, reference):
        # Number of candidates that also appear in reference (list membership,
        # since the serialized form may not be hashable).
        return sum(1 for candidate in candidates if candidate in reference)

    train_serialized_goals = _serialized_split('train')
    test_serialized_goals = _serialized_split('test')
    print(len(train_serialized_goals), len(test_serialized_goals),
          _count_overlap(test_serialized_goals,
                         train_serialized_goals))  # 8434 1000 430

    random.seed(seed)
    np.random.seed(seed)
    goal_generator = GoalGenerator()
    goals = []
    avg_domains = []
    serialized_goals = []
    while len(goals) < total_num:
        goal = goal_generator.get_user_goal()
        # Drop the police domain from the ordering and from the goal itself.
        if 'police' in goal['domain_ordering']:
            no_police = list(goal['domain_ordering'])
            no_police.remove('police')
            goal['domain_ordering'] = tuple(no_police)
            del goal['police']
        try:
            message = goal_generator.build_message(goal)[1]
        except Exception:
            # Skip goals build_message cannot render; a bare ``except:`` would
            # also swallow KeyboardInterrupt and make the loop uninterruptible.
            continue
        avg_domains.append(len(goal['domain_ordering']))
        goals.append({
            "goals": [],
            "ori_goals": goal,
            "description": message,
            "timestamp": str(datetime.datetime.now()),
            "ID": len(goals)
        })
        serialized_goals.append(extract_slot_combination_from_goal(goal))
        if len(serialized_goals) == 1:
            # Show the first serialized goal as a sanity check.
            print(serialized_goals)
    print(len(train_serialized_goals), len(serialized_goals),
          _count_overlap(serialized_goals,
                         train_serialized_goals))  # 8434 1000 199
示例#3
0
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/vhus_simulator_multiwoz.zip'):
        """
        Build the VHUS user model from ``config.json`` and load its weights.

        Args:
            archive_file: archive location forwarded to ``self.load``.
            model_file: model URL forwarded to ``self.load``.
        """
        module_dir = os.path.dirname(os.path.abspath(__file__))
        config_path = os.path.join(module_dir, 'config.json')
        with open(config_path, 'r') as config_fp:
            config = json.load(config_fp)

        data_manager = UserDataManager()
        goal_vocab, usr_vocab, sys_vocab = data_manager.get_voc_size()
        user_model = VHUS(config, goal_vocab, usr_vocab, sys_vocab)
        self.user = user_model.to(device=DEVICE)
        self.goal_gen = GoalGenerator()
        self.manager = data_manager
        # Inference only: put the network in evaluation mode.
        self.user.eval()

        self.load(archive_file, model_file, config['load'])
    def __init__(self):
        """
        Initialize the agenda-based user policy.

        Sets max_turn to 40 and max_initiative to 4, creates a GoalGenerator,
        resets the turn counter, and leaves goal and agenda as None before
        delegating to Policy.__init__.
        """
        self.max_turn, self.max_initiative = 40, 4

        self.goal_generator = GoalGenerator()

        self.__turn = 0
        self.goal = self.agenda = None

        Policy.__init__(self)
    def __init__(self, goal_generator: GoalGenerator):
        """
        Sample a fresh user goal and prepare it for tracking.

        Args:
            goal_generator (GoalGenerator): Goal Generator used to sample the goal.
        """
        sampled_goal = goal_generator.get_user_goal()
        self.domain_goals = sampled_goal

        # Keep the domain order separately and remove it from the goal dict.
        self.domains = list(sampled_goal['domain_ordering'])
        del sampled_goal['domain_ordering']

        for name in self.domains:
            per_domain = self.domain_goals[name]
            # Requested slots start with an unknown value.
            if 'reqt' in per_domain:
                per_domain['reqt'] = dict.fromkeys(per_domain['reqt'], DEF_VAL_UNK)
            # Domains with booking constraints get a 'booked' entry, initially unknown.
            if 'book' in per_domain:
                per_domain['booked'] = DEF_VAL_UNK
    # evaluator = MultiWozEvaluator()
    # sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)

    # user_policy = UserPolicyAgendaMultiWoz()
    #
    # sys_policy = RuleBasedMultiwozBot()
    #
    # user_nlg = TemplateNLG(is_user=True, mode='manual')
    # sys_nlg = TemplateNLG(is_user=False, mode='manual')
    #
    # dst = RuleDST()
    #
    # user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
    #                    model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
    #
    goal_generator = GoalGenerator()
    # while True:
    #     goal = goal_generator.get_user_goal()
    #     if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
    #         break
    # # pprint(goal)
    user_goal = {
        'domain_ordering': ('restaurant', 'hotel', 'taxi'),
        'hotel': {
            'book': {
                'day': 'sunday',
                'people': '6',
                'stay': '4'
            },
            'info': {
                'internet': 'no',
示例#7
0
    def __init__(self,
                 opt,
                 agent,
                 num_extra_trial=2,
                 max_turn=50,
                 max_resp_time=180,
                 model_agent_opt=None,
                 world_tag='',
                 agent_timeout_shutdown=180):
        """Set up the evaluation world: bot, timeouts and a sampled user goal.

        Args:
            opt: option dict; ``opt['is_sandbox']`` selects 'sandbox' vs 'live'.
            agent: the human-side agent, forwarded to the base constructor.
            num_extra_trial: stored as ``self.num_extra_trial``.
            max_turn: stored as ``self.max_turn``.
            max_resp_time: response timeout in seconds.
            model_agent_opt: accepted but not used in this constructor.
            world_tag: stored as ``self.world_tag``.
            agent_timeout_shutdown: stored as ``self.agent_timeout_shutdown``.

        Raises:
            RuntimeError: if no user goal could be sampled within 100 attempts.
        """
        self.opt = opt
        self.agent = agent
        self.turn_idx = 1
        self.hit_id = None
        self.max_turn = max_turn
        self.num_extra_trial = num_extra_trial
        self.dialog = []
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.eval_done = False
        self.chat_done = False
        self.success = False
        self.success_attempts = []
        self.fail_attempts = []
        self.fail_reason = None
        self.understanding_score = -1
        self.understanding_reason = None
        self.appropriateness_score = -1
        self.appropriateness_reason = None
        self.world_tag = world_tag
        self.ratings = ['1', '2', '3', '4', '5']
        super().__init__(opt, agent)

        # set up model agent (other bots can be enabled by adding entries here)
        self.model_agents = {
            # "cambridge": CambridgeBot(),
            # "sequicity": SequicityBot(),
            # "RuleBot": RuleBot(),
            "DQNBot": DQNBot()
        }
        self.model_name = random.choice(list(self.model_agents.keys()))
        self.model_agent = self.model_agents[self.model_name]
        print("Bot is loaded")

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        self.agent_timeout_shutdown = agent_timeout_shutdown

        # set up personas: retry goal sampling, since get_user_goal may raise
        self.goal = None
        goal_generator = GoalGenerator(boldify=True)
        max_goal_trials = 100
        num_goal_trials = 0
        while num_goal_trials < max_goal_trials and self.goal is None:
            try:
                self.goal = goal_generator.get_user_goal()
            except Exception as e:
                print(e)
                num_goal_trials += 1
        if self.goal is None:
            # Previously build_message(None) would fail with an obscure error.
            raise RuntimeError(
                'could not generate a user goal after {} attempts'.format(max_goal_trials))
        self.goal_message, _ = goal_generator.build_message(self.goal)
        # Render the goal description as an HTML bullet list.
        self.goal_text = '<ul>' + ''.join(
            '<li>' + m + '</li>' for m in self.goal_message) + '</ul>'
        print(self.goal_text)

        print(self.goal)
        self.final_goal = deepcopy(self.goal)
        self.state = deepcopy(self.goal)