def detect(opt, printargs=None, print_parser=None):
    """Check a task for offensive language.

    Builds an agent and the task world from ``opt``, runs the world until its
    epoch is done and counts the turns in which any act's ``text``, or any of
    its ``labels`` (falling back to ``eval_labels``), is flagged by
    ``OffensiveLanguageDetector``. Flagged turns are printed when
    ``opt['display_examples']`` is set. Progress is logged every
    ``opt['log_every_n_secs']`` seconds; progress is never logged when that
    option is missing, None or <= 0.

    :param opt: options used to create the agent and the task.
    :param printargs: unused; kept for backward compatibility.
    :param print_parser: parser whose arguments are printed after the model
        is loaded. ``True`` means "use ``opt``", which only works when ``opt``
        is a ``ParlaiParser``. ``False`` or ``None`` disables printing.
    :return: the world's final ``report()``.
    """
    if print_parser is True:
        # ``True`` can only be honoured when ``opt`` itself is a parser.
        # Otherwise ``print_parser.opt = ...`` below would be set on a bool
        # and raise AttributeError.
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveLanguageDetector()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs is None or log_every_n_secs <= 0:
        # Non-positive (or unset) means: never log progress.
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0

    def _act_is_offensive(act):
        # An act is offensive if its text or any of its labels is.
        if bad.contains_offensive_language(act.get('text', '')):
            return True
        labels = act.get('labels', act.get('eval_labels', ''))
        return any(bad.contains_offensive_language(label) for label in labels)

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        world.parley()
        if any(_act_is_offensive(a) for a in world.acts):
            if opt['display_examples']:
                print(world.display() + "\n~~")
            cnt += 1
        elapsed = log_time.time()
        if elapsed > log_every_n_secs:
            tot_time += elapsed
            report = world.report()
            num_examples = world.num_examples()
            log = {'total': report['total']}
            # Guard against a world that reports zero examples.
            if num_examples > 0:
                done = report['total'] / num_examples
                log['done'] = '%.2f%%' % (100 * done)
                if done > 0:
                    log['eta'] = int(tot_time / done - tot_time)
            log['offenses'] = cnt
            print(str(int(tot_time)) + "s elapsed: " + str(log))
            log_time.reset()
    if world.epoch_done():
        print("EPOCH DONE")
    print(
        str(cnt) + " offensive messages found out of " +
        str(world.num_examples()) + " messages.")
    return world.report()
Пример #2
0
def detect(opt, printargs=None, print_parser=None):
    """Check a task for offensive language.

    Builds an agent and the task world from ``opt`` and runs the world until
    its epoch is done. For every act it checks the ``text`` and each of the
    ``labels`` (falling back to ``eval_labels``) with
    ``OffensiveLanguageDetector``. Every flagged text adds one to the running
    count. When ``opt['display_examples']`` is set, turns with flagged text
    are printed together with the detected words. Progress is logged through
    ``TimeLogger`` every ``opt['log_every_n_secs']`` seconds; progress is
    never logged when that option is missing, None or <= 0.

    :param opt: options used to create the agent and the task.
    :param printargs: unused; kept for backward compatibility.
    :param print_parser: parser whose arguments are printed after the model
        is loaded. ``True`` means "use ``opt``", which only works when ``opt``
        is a ``ParlaiParser``. ``False`` or ``None`` disables printing.
    :return: the world's final ``report()``.
    """
    if print_parser is True:
        # ``True`` can only be honoured when ``opt`` itself is a parser.
        # Otherwise ``print_parser.opt = ...`` below would be set on a bool
        # and raise AttributeError.
        print_parser = opt if isinstance(opt, ParlaiParser) else None
    elif print_parser is False:
        print_parser = None
    random.seed(42)

    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveLanguageDetector()

    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs is None or log_every_n_secs <= 0:
        # Non-positive (or unset) means: never log progress.
        log_every_n_secs = float('inf')
    log_time = TimeLogger()

    def _offensive_words(act):
        # Detector results for the act's text and then each label, keeping
        # only the truthy ones (the detected offensive words).
        texts = [act.get('text', '')]
        texts.extend(act.get('labels', act.get('eval_labels', '')))
        found = (bad.contains_offensive_language(t) for t in texts)
        return [w for w in found if w]

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        world.parley()
        words = []
        for a in world.acts:
            words.extend(_offensive_words(a))
        if words and opt['display_examples']:
            print(world.display())
            print("[Offensive words detected:]", ', '.join(words))
            print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        cnt += len(words)
        if log_time.time() > log_every_n_secs:
            report = world.report()
            log = {'offenses': cnt}
            text, log = log_time.log(report['exs'], world.num_examples(), log)
            print(text)

    if world.epoch_done():
        print("EPOCH DONE")
    print(
        str(cnt) + " offensive messages found out of " +
        str(world.num_examples()) + " messages.")
    return world.report()
Пример #3
0
    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        world_tag='NONE',
        ir_agent=None,
        task='',
        wiki_title_to_passage=None,
    ):
        """Set up a wizard/apprentice chat world and draw a persona.

        :param opt: options; reads ``min_turns``, ``max_turns``,
            ``is_sandbox``, ``max_resp_time`` and ``num_passages_retrieved``.
        :param agents: the chat agents; they are re-sorted by ``id`` in a
            randomly chosen direction.
        :param shared: passed through to the parent world's ``__init__``.
        :param world_tag: label stored on the world.
        :param ir_agent: retrieval agent stored on the world.
        :param task: unused.
        :param wiki_title_to_passage: title-to-passage mapping stored on the
            world.
        """
        self.turn_idx = 0
        self.min_turns = opt['min_turns']
        self.max_turns = opt['max_turns']
        # randint's upper bound is exclusive, so num_turns falls in
        # [min_turns + 1, max_turns].
        self.num_turns = np.random.randint(self.min_turns, self.max_turns) + 1
        self.dialog = []
        self.wizard_eval = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        self.num_passages_to_retrieve = opt['num_passages_retrieved']
        super().__init__(opt, agents, shared)
        # Sort agents by id, ascending or descending with (roughly) equal
        # probability, so the first agent is chosen at random.
        self.agents = sorted(agents,
                             key=lambda x: x.id,
                             reverse=random.random() <= 0.5)
        # Personas and retriever: keep drawing personas from the first
        # agent's generator until one yields at least one topic.
        self.persona_generator = self.agents[0].persona_generator
        self.relevant_topics = []
        while not self.relevant_topics:
            self.persona_to_topics = {}
            self.persona_idx, persona_data = (
                self.persona_generator.pop_persona())
            for p in persona_data:
                # Strip a single leading space; startswith (unlike p[0])
                # also tolerates an empty persona line.
                if p.startswith(' '):
                    p = p[1:]
                if p in self.persona_to_topics:
                    continue
                self.persona_to_topics[p] = []
                for t in set(self.persona_generator.get_topics(p)):
                    self.relevant_topics.append(t + ' ({})'.format(p))
                    self.persona_to_topics[p].append(t)

        self.ir_agent = ir_agent
        self.setup_tokenizer(opt)
        self.chosen_topic = ''
        self.chosen_topic_passage = {}
        self.OLD = OffensiveLanguageDetector()
        # Load the title to passage dictionary
        self.wiki_title_to_passage = wiki_title_to_passage
Пример #4
0
    def __init__(self, opt):
        """Load caption data and the persisted stack of example indices.

        Builds the personality_captions and image_chat datasets, then loads
        ``train``/``val``/``test`` from personality_captions. In
        second-response mode it instead loads ``train``/``valid``/``test``
        from image_chat and keeps only entries with more than one dialog
        turn. The index stack is read from a pickle in the current working
        directory if present; otherwise it is created and saved.
        """
        self.second_resp = opt.get('second_response')
        mode = 'second_response' if self.second_resp else 'first_response'
        suffix = '_sandbox' if opt['is_sandbox'] else ''
        self.examples_idx_stack_path = os.path.join(
            os.getcwd(), './{}_examples_stack{}.pkl'.format(mode, suffix))
        self.OLD = OffensiveLanguageDetector()
        self.opt = opt
        build_pc(opt)
        build_ic(opt)

        dataset = 'image_chat' if self.second_resp else 'personality_captions'
        path_template = os.path.join(self.opt['datapath'], '{}/{}.json')
        # image_chat names its validation split 'valid' rather than 'val'.
        splits = ['train', 'valid' if self.second_resp else 'val', 'test']
        self.data = []
        for split_name in splits:
            with open(path_template.format(dataset, split_name)) as fin:
                self.data += json.load(fin)

        if self.second_resp:
            self.data = [
                entry for entry in self.data if len(entry['dialog']) > 1
            ]

        if not os.path.exists(self.examples_idx_stack_path):
            self.idx_stack = []
            self.add_idx_stack()
            self.save_idx_stack()
        else:
            # NOTE(review): unpickles a local state file written by this
            # tool; do not point this path at untrusted data.
            with open(self.examples_idx_stack_path, 'rb') as stack_file:
                self.idx_stack = pickle.load(stack_file)
Пример #5
0
    def __init__(self, opt, agents=None, shared=None):
        """Store the agents and options and set up hard-coded media lists.

        Also picks a random participant index in
        ``[0, opt['participants'] - 1]``.
        """
        # Add passed in agents directly.
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.agents = agents
        self.acts = [None] * len(agents)
        self.episodeDone = False
        self.opt = opt
        self.data = []
        self.offensive_lang_detector = OffensiveLanguageDetector()
        self.rand_index = random.randint(0, self.opt["participants"] - 1)

        # Hard-coded list of local image file paths.
        image_dir = "/projects2/ParlAI/data/yfcc_images/"
        image_names = (
            "1e22a9cf867d718551386b427c3b6d18.jpg",
            "96472caea58db27769f1c282e2ac0.jpg",
            "f09d8fb76822158de129acb0fef463.jpg",
            "6e4ccc739ff44ed11da20ad9892317.jpg",
            "e7e1844aa9e67cddc6ffe8804d76e45b.jpg",
            "5547b3852afec328a491a696ace99a.jpg",
            "b326345ae2b2bd14ebf74aaa31e571a.jpg",
            "75a13ebe4be7ab5b3f68f692d7db081.jpg",
            "246eea26a3fc2d886be795790a7495.jpg",
            "010722aa6d2327deddb4ead5e089ea.jpg",
        )
        self.imgs = [image_dir + name for name in image_names]

        # Hard-coded YouTube videos, in their embeddable URL form.
        video_ids = ("7gUv0xcFqMk", "6vYJyOGKCHE", "3SJ0Rd7XU4Y")
        self.links = [
            "https://www.youtube.com/embed/" + vid for vid in video_ids
        ]
Пример #6
0
 def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
     """Initialise the world state and per-task-type settings.

     Reads ``is_sandbox``, ``max_resp_time``, ``num_images`` and
     ``task_type`` from ``opt`` (plus the optional
     ``multiple_personality`` flag). It looks up the title and config for
     ``task_type`` in ``TASK_TYPE_TO_TITLE`` / ``TASK_TYPE_TO_CONFIG``.
     """
     self.turn_idx = 0
     self.task_type = 'live' if not opt['is_sandbox'] else 'sandbox'
     self.chat_done = False
     self.world_tag = world_tag
     # Maximum time to wait for a response, in seconds.
     self.max_resp_time = opt['max_resp_time']
     super().__init__(opt, agents, shared)
     self.agents = agents
     self.offensive_lang_detector = OffensiveLanguageDetector()
     # The first agent is the one this world interacts with.
     self.agent = agents[0]
     self.data = []
     self.exact_match = False
     self.num_images = opt['num_images']
     self.multiple_personality = opt.get('multiple_personality', False)
     self.eval = 0
     kind = opt['task_type']
     self.data_type = kind
     self.task_type_title = TASK_TYPE_TO_TITLE[kind]
     self.config = TASK_TYPE_TO_CONFIG[kind]
Пример #7
0
 def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
     """Initialise the world state, config and image directory.

     Uses ``config_second`` when ``opt['second_response']`` is truthy and
     ``config_first`` otherwise. Images are read from ``opt['yfcc_path']``
     when it is set and non-empty, else from ``<datapath>/yfcc_images``.
     """
     self.turn_idx = 0
     self.task_type = 'live' if not opt['is_sandbox'] else 'sandbox'
     self.chat_done = False
     self.world_tag = world_tag
     # Maximum time to wait for a response, in seconds.
     self.max_resp_time = opt['max_resp_time']
     super().__init__(opt, agents, shared)
     self.agents = agents
     # The first agent is the one this world interacts with.
     self.agent = agents[0]
     self.offensive_lang_detector = OffensiveLanguageDetector()
     self.data = []
     self.exact_match = False
     self.num_images = opt['num_images']
     self.second_resp = opt.get('second_response', False)
     self.config = config_second if self.second_resp else config_first
     self.image_path = opt.get('yfcc_path') or os.path.join(
         opt['datapath'], 'yfcc_images')
Пример #8
0
class MTurkWizardOfWikipediaWorld(MultiAgentDialogWorld):
    """World where two agents have a dialogue.

    One agent (the apprentice) chats freely, perhaps based on a persona,
    while the other is the 'wizard', who bases his/her responses on
    documents (i.e. sentences) retrieved based on what the other agent says.
    """

    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        world_tag='NONE',
        ir_agent=None,
        task='',
        wiki_title_to_passage=None,
    ):
        """Set up the world, draw a persona and collect its topics.

        :param opt: options; reads ``min_turns``, ``max_turns``,
            ``is_sandbox``, ``max_resp_time``, ``num_passages_retrieved``
            and, via ``setup_tokenizer``, ``dict_language``.
        :param agents: the two chat agents; they are re-sorted by ``id`` in a
            randomly chosen direction.
        :param shared: passed through to ``MultiAgentDialogWorld``.
        :param world_tag: label used in log output.
        :param ir_agent: retrieval agent queried by ``retrieve_passages``.
        :param task: unused.
        :param wiki_title_to_passage: mapping from a topic title to its
            passage text (first line is the title).
        """
        self.turn_idx = 0
        self.min_turns = opt['min_turns']
        self.max_turns = opt['max_turns']
        # randint's upper bound is exclusive, so num_turns falls in
        # [min_turns + 1, max_turns].
        self.num_turns = np.random.randint(self.min_turns, self.max_turns) + 1
        self.dialog = []
        self.wizard_eval = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        self.num_passages_to_retrieve = opt['num_passages_retrieved']
        super().__init__(opt, agents, shared)
        # Sort agents by id, ascending or descending with (roughly) equal
        # probability, so either role may be the first agent.
        self.agents = sorted(agents,
                             key=lambda x: x.id,
                             reverse=random.random() <= 0.5)
        # Personas and retriever: keep drawing personas from the first
        # agent's generator until one yields at least one topic.
        self.persona_generator = self.agents[0].persona_generator
        self.relevant_topics = []
        while not self.relevant_topics:
            self.persona_to_topics = {}
            self.persona_idx, persona_data = (
                self.persona_generator.pop_persona())
            for p in persona_data:
                # Strip a single leading space; startswith (unlike p[0])
                # also tolerates an empty persona line.
                if p.startswith(' '):
                    p = p[1:]
                if p in self.persona_to_topics:
                    continue
                self.persona_to_topics[p] = []
                for t in set(self.persona_generator.get_topics(p)):
                    self.relevant_topics.append(t + ' ({})'.format(p))
                    self.persona_to_topics[p].append(t)

        self.ir_agent = ir_agent
        self.setup_tokenizer(opt)
        self.chosen_topic = ''
        self.chosen_topic_passage = {}
        self.OLD = OffensiveLanguageDetector()
        # Load the title to passage dictionary
        self.wiki_title_to_passage = wiki_title_to_passage

    def episode_done(self):
        """Return True once the chat has ended."""
        return self.chat_done

    def setup_tokenizer(self, opt):
        """Load the NLTK punkt sentence tokenizer for ``opt['dict_language']``.

        Downloads the punkt data when it is not found locally.
        """
        try:
            import nltk
        except ImportError:
            raise ImportError('Please install nltk (e.g. pip install nltk).')
        # nltk-specific setup
        st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
        try:
            self.sent_tok = nltk.data.load(st_path)
        except LookupError:
            nltk.download('punkt')
            self.sent_tok = nltk.data.load(st_path)

    def sufficient_overlap(self, text, sent_dict):
        """Return True if ``text`` overlaps enough with any checked sentence.

        Both sides are lowercased and tokenized with ``split_tokenize``.
        Stopwords are dropped and each word is cut to its first 4
        characters. The result is True as soon as one value of ``sent_dict``
        shares at least ``opt['word_overlap_threshold']`` (default 2) such
        prefixes with ``text``.
        """
        threshold = self.opt.get('word_overlap_threshold', 2)
        text_words = {
            w[:4] for w in split_tokenize(text.lower()) if w not in STOPWORDS
        }
        for sentence in sent_dict.values():
            sentence_words = {
                w[:4] for w in split_tokenize(sentence.lower())
                if w not in STOPWORDS
            }
            if len(text_words & sentence_words) >= threshold:
                return True
        return False

    def _split_passage(self, passage, max_length):
        """Split a passage into its title and its leading sentences.

        The first line of ``passage`` is the title. The remaining lines are
        joined and sentence-tokenized, and the first character of the first
        sentence is dropped (presumably a leading separator left by the
        passage format; TODO confirm). Sentences of length <= 1 are skipped.
        Sentences are collected until their space-joined length exceeds
        ``max_length``; the sentence that crosses the limit is kept.

        :return: ``(title, sentences)``; ``sentences`` is empty when the
            passage has no body (previously this raised IndexError).
        """
        lines = passage.split('\n')
        title = lines[0]
        tokenized = self.sent_tok.tokenize(" ".join(lines[1:]))
        if tokenized:
            tokenized[0] = tokenized[0][1:]
        sentences = []
        for sent in tokenized:
            if len(sent) > 1:
                sentences.append(sent)
                if len(" ".join(sentences)) > max_length:
                    break
        return title, sentences

    def parley(self):
        """Run one turn of the conversation.

        On the first turn both agents get their start instructions. The
        first agent then picks a topic from ``self.relevant_topics`` and the
        wizard is shown the truncated passage for that topic. On every turn
        each agent acts in order. Passages are retrieved for each message:
        a wizard receives them for its own message (``self_retrieved_passages``)
        and for its partner's message (``partner_retrieved_passages``).
        """
        self.turn_idx += 1

        # Initial message value. This dict is reused (and mutated) for every
        # system message sent during this turn.
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'

        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))

        # First turn: give the first agent the list of topics to choose from.
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                # Send each agent its role-specific start instructions.
                control_msg['text'] = self.get_instruction(tag='start',
                                                           agent_id=agent.id)
                if agent.id == WIZARD:
                    control_msg['description'] = config['wizard_onboarding']
                else:
                    control_msg['description'] = config[
                        'apprentice_onboarding']
                agent.observe(validate(control_msg))
                if idx == 0:
                    time.sleep(3)
            # Send the first agent the list of relevant topics.
            self.agents[0].observe(
                validate({
                    'id': 'SYSTEM',
                    'text': PICK_TOPIC_MSG,
                    'relevant_topics': self.relevant_topics,
                }))

            topic_act = self.agents[0].act(timeout=self.max_resp_time)
            timed_out = self.check_timeout(topic_act)
            self.chosen_topic = topic_act['text']
            if timed_out:
                # check_timeout has already told the partner the chat is
                # over and set chat_done. Do not send passages to the
                # partner or ask either agent to act again.
                return
            if self.agents[0].id == APPRENTICE:
                pick_msg = AFTER_PICK_TOPIC_MSG
            else:
                pick_msg = AFTER_PICK_TOPIC_WIZARD_MSG
            self.agents[0].observe({'id': 'SYSTEM', 'text': pick_msg})

            # Send the wiki page for the chosen topic to the wizard.
            for idx, agent in enumerate(self.agents):
                if agent.id != WIZARD:
                    continue
                passage = self.wiki_title_to_passage.get(self.chosen_topic, '')
                if passage == '':
                    break
                title, sentences = self._split_passage(passage,
                                                       MAX_DOC_LEN * 2)
                # Only a wizard who did not pick the topic is told about it.
                msg_text = (AFTER_PARTNER_PICK_TOPIC_WIZARD_MSG
                            if idx == 1 else "")
                control_msg['text'] = msg_text
                control_msg['chosen_topic_passages'] = [[title, sentences]]
                agent.observe(validate(control_msg))
                self.chosen_topic_passage = {
                    'topic': self.chosen_topic,
                    'full_passage': passage,
                    'shown_passage': sentences,
                }

        # Once past the minimum number of turns, tell the workers they may
        # end the chat.
        if self.turn_idx == self.num_turns + 1:
            for agent in self.agents:
                control_msg['text'] = self.get_instruction(
                    tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))

        # Otherwise, proceed with the regular turn: each agent acts in order.
        acts = self.acts
        for idx, agent in enumerate(self.agents):
            # The wizard gets 50% more time to respond than the apprentice.
            max_response_time = self.max_resp_time * (
                1 if agent.id == APPRENTICE else 1.5)
            acts[idx] = agent.act(timeout=max_response_time)
            self.check_timeout(acts[idx])

            # An act with episode_done set ends the chat.
            if acts[idx]['episode_done']:
                self.chat_done = True
                for ag in self.agents:
                    if ag != agent and ag.some_agent_disconnected:
                        control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                        ag.observe(validate(control_msg))
                        return
                if self.turn_idx > self.num_turns:
                    for ag in self.agents:
                        ag.observe(validate(acts[idx]))
                        # Have the apprentice rate the wizard.
                        if ag.id == APPRENTICE:
                            control_msg['text'] = EVAL_WIZARD_MSG
                            control_msg['wizard_eval'] = True
                            ag.observe(validate(control_msg))
                            act = ag.act(timeout=self.max_resp_time)
                            self.check_timeout(act)
                            try:
                                # Clamp the rating to [1, 5].
                                w_ev = max(1, min(int(act['text']), 5))
                            except (ValueError, TypeError):
                                # Non-numeric reply, e.g. after a disconnect.
                                w_ev = -1
                            self.wizard_eval = w_ev
                        control_msg['text'] = CHAT_ENDED_MSG
                        ag.observe(validate(control_msg))
                return

            # Record this message for the saved dialog.
            msg_info = {
                'speaker': '{}_{}'.format(idx, agent.id),
                'text': acts[idx]['text'],
                'turn': self.turn_idx,
                'time': time.time(),
                'offensive': self.OLD.contains_offensive_language(
                    acts[idx]['text']),
            }

            # Get clicked passages and checked sentences from the wizard.
            if 'clicked_passages' in acts[idx]:
                msg_info['clicked_passages'] = acts[idx]['clicked_passages']
                checked_sents = {}
                for k, v in acts[idx]['checked_sentences'].items():
                    if k == 'no_passages_used':
                        checked_sents[k] = v
                        continue
                    # Keys have the form '<person>_<topic idx>_<sent idx>'.
                    split = k.split('_')
                    person, topic_idx, sent_idx = split[0], split[1], split[2]
                    if person in ('partner', 'self'):
                        # 'partner' refers to the passages retrieved for the
                        # previous message, 'self' to the message before it.
                        prev = self.dialog[-1 if person == 'partner' else -2]
                        titles = [
                            p.split('\n')[0] for p in prev['full_passages']
                        ]
                        topic = titles[int(topic_idx)]
                    else:
                        topic = self.chosen_topic
                    cs_key = '_'.join(
                        [person, '_'.join(topic.split(' ')), sent_idx])
                    checked_sents[cs_key] = v
                msg_info['checked_sentence'] = checked_sents
                msg_info['checked_passage'] = acts[idx]['checked_passages']
                msg_info['good_message'] = self.sufficient_overlap(
                    msg_info['text'], msg_info['checked_sentence'])

            # Retrieve passages for this message.
            ir_passages = self.retrieve_passages(copy.deepcopy(acts[idx]))
            passages = self.format_passages(ir_passages)
            msg_info['full_passages'] = ir_passages
            msg_info['shown_passages'] = passages

            if agent.id == WIZARD:
                # Give the wizard the passages for its own message.
                control_msg['text'] = ''
                control_msg['self_retrieved_passages'] = passages
                agent.observe(validate(control_msg))

            self.dialog.append(msg_info)

            for other_agent in self.agents:
                if other_agent != agent:
                    other_agent.observe(validate(acts[idx]))
                    if other_agent.id == WIZARD:
                        control_msg[
                            'text'] = PARTNER_RETRIEVED_PASSAGES_INST_MSG
                        control_msg['partner_retrieved_passages'] = passages
                        other_agent.observe(validate(control_msg))

    def format_passages(self, ir_passages, max_length=MAX_DOC_LEN):
        """Convert retrieved passages into ``[title, sentences]`` pairs.

        A one-element ``ir_passages`` is treated as "no passages";
        ``retrieve_passages`` falls back to a one-element list when the
        retriever returns no ``text_candidates``.
        NOTE(review): a genuine single passage (e.g. when
        num_passages_retrieved is 1) is also reported as 'No Passages
        Retrieved'.
        """
        if len(ir_passages) == 1:  # Didn't receive any passages
            return [['No Passages Retrieved', []]]
        passages = []
        for passage in ir_passages:
            title, sentences = self._split_passage(passage, max_length)
            passages.append([title, sentences])
        return passages

    def retrieve_passages(self, act, num_passages=None):
        """Query ``ir_agent`` with ``act`` and return up to ``num_passages``.

        ``num_passages`` defaults to ``opt['num_passages_retrieved']``. When
        the retriever returns no ``text_candidates``, a single-element list
        holding its ``text`` (or ``""``) is returned.
        """
        if not num_passages:
            num_passages = self.num_passages_to_retrieve
        self.ir_agent.observe(act)
        action = self.ir_agent.act()
        passages = action.get('text_candidates', [action.get('text', "")])
        return passages[:num_passages]

    def get_instruction(self, agent_id=None, tag='first'):
        """Return the system message for ``tag`` (None for unknown tags)."""
        if tag == 'start':
            start_msg = WIZARD_START_MSG if agent_id == WIZARD else APPRENTICE_START_MSG
            return start_msg.format(self.num_turns)
        if tag == 'timeout':
            return TIMEOUT_MSG
        if tag == 'exceed_min_turns':
            return EXCEED_MIN_TURNS_MSG.format(self.num_turns)

    def check_timeout(self, act):
        """Handle a timed-out act.

        If ``act`` is a '[TIMEOUT]' act with episode_done set, send the
        timeout message to every other agent, mark the chat done and return
        True. Otherwise return False.
        """
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        else:
            return False

    def reset_random(self):
        """Redraw the number of turns for the conversation."""
        self.num_turns = np.random.randint(self.min_turns, self.max_turns) + 1

    def save_data(self):
        """Pickle the dialog and its metadata under ``opt['data_path']``.

        If the chat did not reach the turn limit, or any agent abandoned,
        returned, disconnected or expired, the persona is pushed back onto
        the generator's stack and the file name gets an '_incomplete' suffix.
        Finished chats are also scored with ``check_wizard_quality``.
        """
        # save persona_idx_stack
        convo_finished = self.turn_idx >= self.num_turns + 1
        for ag in self.agents:
            if (ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected
                    or ag.hit_is_expired):
                convo_finished = False
        if not convo_finished:
            self.persona_generator.push_persona(self.persona_idx)
            print("\n**Push persona {} back to stack. **\n".format(
                self.persona_idx))
        self.agents[0].persona_generator.save_idx_stack()

        data_path = self.opt['data_path']
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        self.convo_finished = convo_finished
        self.wizard_worker = ''
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.pkl'.format(
                    time.strftime("%Y%m%d-%H%M%S"),
                    np.random.randint(0, 1000),
                    self.task_type,
                ),
            )
            self.good_wiz, self.wizard_worker = self.check_wizard_quality()
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.pkl'.format(
                    time.strftime("%Y%m%d-%H%M%S"),
                    np.random.randint(0, 1000),
                    self.task_type,
                ),
            )
            self.good_wiz = True
        data = {
            'persona': self.persona_to_topics,
            'relevant_topics': self.relevant_topics,
            'chosen_topic_passage': self.chosen_topic_passage,
            'dialog': self.dialog,
            'speaker_with_persona': self.agents[0].worker_id,
            'workers': [ag.worker_id for ag in self.agents],
            'n_turn': self.num_turns,
            'hit_ids': [ag.hit_id for ag in self.agents],
            'assignment_ids': [ag.assignment_id for ag in self.agents],
            'wizard_eval': self.wizard_eval,
            'chosen_topic': self.chosen_topic,
            'wizard_good': convo_finished and self.good_wiz,
            'good_wizard_worker': self.wizard_worker if self.good_wiz else '',
            'bad_wizard_worker':
            self.wizard_worker if not self.good_wiz else '',
        }
        # Use a context manager so the file is flushed and closed
        # (previously the handle was left open).
        with open(filename, 'wb') as f:
            pickle.dump(data, f)
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def check_wizard_quality(self):
        """Judge whether the wizard wrote enough well-grounded messages.

        Only called if the conversation finishes. Counts dialog entries
        whose ``good_message`` is truthy. Outside the sandbox, it appends the
        wizard's worker id to ``bad_wizards.txt`` or ``good_wizards.txt`` in
        ``opt['current_working_dir']`` (presumably used elsewhere to
        soft-block workers; confirm).

        :return: ``(is_good, wizard_worker_id)``; is_good is True when the
            count reaches ``opt['num_good_sentence_threshold']``.
        """
        num_good_sents = sum(
            1 for info in self.dialog if info.get('good_message'))
        wizard_worker = [w for w in self.agents if w.id == WIZARD][0].worker_id
        data_path = self.opt['current_working_dir']
        bad_wizards = os.path.join(data_path, 'bad_wizards.txt')
        good_wizards = os.path.join(data_path, 'good_wizards.txt')
        if num_good_sents < self.opt['num_good_sentence_threshold']:
            if not self.opt['is_sandbox']:
                with open(bad_wizards, 'a') as f:
                    f.write(wizard_worker + '\n')
            return False, wizard_worker
        else:
            if not self.opt['is_sandbox']:
                with open(good_wizards, 'a') as f:
                    f.write(wizard_worker + '\n')
            return True, wizard_worker

    def review_work(self):
        """Reject the work of any agent who sent a message flagged offensive.

        Agents are reviewed in parallel threads.
        """
        # NOTE(review): the nested helper is declared global, presumably so
        # joblib can reference it by name; confirm whether that is still
        # needed with the threading backend.
        global review_agent

        def review_agent(ag):
            role = ag.id
            for d in self.dialog:
                if role in d['speaker']:
                    if d['offensive']:
                        ag.reject_work(reason='Your HIT has been rejected '
                                       'because we detected offensive '
                                       'language in your submission.')

        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(review_agent)(agent)
                                      for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk agent
        is disconnected then it could prevent other mturk agents from
        completing.
        """
        global shutdown_agent

        def shutdown_agent(agent):
            agent.shutdown()

        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(shutdown_agent)(agent)
                                      for agent in self.agents)
Пример #9
0
class MTurkPersonalityCaptionsWorld(MultiAgentDialogWorld):
    """World in which an agent observes a series of images, each shown with
    a personality, and writes an engaging comment about every image.

    One turn is played per image (``opt['num_images']``). Each comment is
    screened for offensive language; the worker gets at most three attempts
    per image, after which the last comment is recorded together with the
    detector's verdict.
    """

    def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
        self.turn_idx = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        super().__init__(opt, agents, shared)
        self.agents = agents
        self.offensive_lang_detector = OffensiveLanguageDetector()
        # parley() only ever interacts with the first agent.
        self.agent = agents[0]
        self.data = []
        self.exact_match = False
        self.num_images = opt['num_images']
        # When True, personality and image are popped together as one pair
        # instead of from two independent generators.
        self.multiple_personality = opt.get('multiple_personality', False)
        # Replaced by {score: text} once the final evaluation is collected.
        self.eval = 0
        self.data_type = opt['task_type']
        self.task_type_title = TASK_TYPE_TO_TITLE[opt['task_type']]
        self.config = TASK_TYPE_TO_CONFIG[opt['task_type']]

    def episode_done(self):
        return self.chat_done

    def parley(self):
        """Run the whole HIT for the COMMENTER.

        For each image the worker is sent the image (base64-encoded JPEG),
        the instructions and, for the 'personality' task type, the
        personality; their comment is then collected and screened for
        offensive language. The loop stops early on timeout/disconnect.
        After the last image the worker is asked for an evaluation, which is
        stored in ``self.eval`` as ``{score: text}``.
        """
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'

        while self.turn_idx < self.num_images:
            print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
            # Send personality + image to turker
            if not self.multiple_personality:
                self.pers_idx, personality = (
                    self.agent.personality_generator.pop_personality())
                self.image_hash, image_path = (
                    self.agent.image_generator.pop_image())
            else:
                pair = self.agent.personality_and_image_generator.pop_pair()
                self.pair_idx, personality, self.image_hash, image_path = pair
            img = load_image(image_path)
            buffered = BytesIO()
            img.save(buffered, format='JPEG')
            control_msg['image'] = base64.b64encode(
                buffered.getvalue()).decode('ascii')
            if self.data_type == 'personality':
                control_msg['personality_text'] = (
                    '<b><span style="color:blue">'
                    '{}\n</span></b>'.format(personality.strip()))
            control_msg['text'] = self.get_instruction(
                tag='start', agent_id=self.agent.id,
                turn_num=self.turn_idx + 1)
            control_msg['description'] = self.config['task_description']
            control_msg['task_type_title'] = self.task_type_title
            self.agent.observe(validate(control_msg))
            time.sleep(1)

            # Collect comment from turker (at most three attempts).
            offensive_counter = 0
            while offensive_counter < 3:
                self.acts[0] = self.agent.act(timeout=self.max_resp_time)
                if self.check_timeout(self.acts[0]):
                    break
                comment = self.acts[0]['text']
                offensive = (self.offensive_lang_detector
                             .contains_offensive_language(comment))
                if not offensive:
                    break
                offensive_counter += 1
                if offensive_counter < 3:
                    # Tell Turker to not be offensive! Only sent when another
                    # attempt will actually be collected; after the last one
                    # the world moves on, so asking for a revision would
                    # mislead the worker.
                    self.agent.observe(validate({
                        'id': 'SYSTEM',
                        'text': OFFENSIVE_MSG,
                    }))

            if self.chat_done:
                break
            self.data.append({
                'comment': comment,
                'personality': personality,
                'image_hash': self.image_hash,
                'image_path': image_path,
                'contains_offensive_language': offensive,
            })
            self.turn_idx += 1

        if self.turn_idx == self.num_images:
            control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images)
            control_msg['eval'] = True
            self.agent.observe(validate(control_msg))
            act = self.agent.act(timeout=self.max_resp_time)
            self.check_timeout(act)
            try:
                ev_val = int(act['eval'])
                ev_text = act['text']
            except Exception:  # e.g. a disconnect: no usable 'eval' field
                ev_val, ev_text = (-1, 'Eval not received')
            self.eval = {ev_val: ev_text}
        self.chat_done = True

    def get_instruction(self, agent_id=None, tag='first', turn_num=0):
        """Return the message text for ``tag`` ('start' or 'timeout').

        Any other tag yields None.
        """
        if tag == 'start':
            return START_MSGS[self.data_type].format(turn_num)
        if tag == 'timeout':
            return TIMEOUT_MSG

    def check_timeout(self, act):
        """Return True (and end the chat) if ``act`` is a timeout or a
        disconnect; on timeout the other agents are notified."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        elif act['text'] == '[DISCONNECT]':
            self.chat_done = True
            return True
        else:
            return False

    def save_data(self):
        """Pickle the collected comments under ``opt['data_path']``.

        If any agent abandoned, returned, disconnected from or let the HIT
        expire, the most recently popped personality/image (or pair) is
        pushed back onto its generator's stack and the file name gets an
        ``_incomplete`` suffix. The generators' index stacks are saved in
        every case.
        """
        convo_finished = not any(
            ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected
            or ag.hit_is_expired for ag in self.agents)
        if not convo_finished:
            ag = self.agents[-1]
            if not self.multiple_personality:
                ag.personality_generator.push_personality(self.pers_idx)
                ag.image_generator.push_image(self.image_hash)
                print('\n**Push personality {} back to stack. **\n'.format(
                    self.pers_idx))
                print('\n**Push image {} back to stack. **\n'.format(
                    self.image_hash))
            else:
                ag.personality_and_image_generator.push_pair(self.pair_idx)
                print('\n**Push pair {} back to stack. **\n'.format(
                    self.pair_idx))
        self.agents[0].personality_generator.save_idx_stack()
        self.agents[0].image_generator.save_idx_stack()
        self.agents[0].personality_and_image_generator.save_idx_stack()
        data_path = self.opt['data_path']
        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(data_path, exist_ok=True)
        suffix = '' if convo_finished else '_incomplete'
        filename = os.path.join(
            data_path,
            '{}_{}_{}{}.pkl'.format(time.strftime('%Y%m%d-%H%M%S'),
                                    np.random.randint(0, 1000),
                                    self.task_type, suffix))
        comments = [d['comment'] for d in self.data]
        # Flag workers whose first comment matches the rest (see
        # _exact_match); review_work rejects such HITs.
        if len(comments) >= 2 and _exact_match(comments[0], comments[1:]):
            self.exact_match = True

        # Use a context manager so the file handle is always closed.
        with open(filename, 'wb') as f:
            pickle.dump(
                {
                    'data': self.data,
                    'worker': self.agents[0].worker_id,
                    'hit_id': self.agents[0].hit_id,
                    'assignment_id': self.agents[0].assignment_id,
                    'exact_match': self.exact_match,
                    'task_eval': self.eval
                }, f)
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def review_work(self):
        """Reject the HIT for offensive or duplicated comments.

        Offensive language takes precedence over identical comments.
        ``self.exact_match`` is set in ``save_data``, so this is presumably
        called after it.
        """
        contains_offense = any(d['contains_offensive_language']
                               for d in self.data)

        def review_agent(ag):
            if contains_offense:
                ag.reject_work(reason='We have rejected this HIT because at '
                               'least one of your comments contains '
                               'offensive language')
                print('Rejected work for agent {} for '
                      'offensive language'.format(ag.worker_id))
            elif self.exact_match:
                ag.reject_work(reason='We have rejected this HIT because all '
                               'of your comments are the exact same')
                print('Rejected work for agent {} for '
                      'same comments'.format(ag.worker_id))

        # A local closure is enough for the threading backend; a module-level
        # ``global`` could be overwritten by another world reviewing
        # concurrently in a different thread.
        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(review_agent)(agent)
                                      for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk agent
        is disconnected then it could prevent other mturk agents from
        completing.
        """

        def shutdown_agent(agent):
            agent.shutdown()

        # Threading backend: the local function does not need to be global.
        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(shutdown_agent)(agent)
                                      for agent in self.agents)
Пример #10
0
class MTurkImageChatWorld(MultiAgentDialogWorld):
    """World where an agent observes 5 images and 5 comments,
       with 5 different personalities,
       and writes engaging responses to the comments

    The number of images actually comes from ``opt['num_images']``. With
    ``opt['second_response']`` set, the worker is also shown an existing
    first response and writes the second response.
    """

    def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
        self.turn_idx = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        super().__init__(opt, agents, shared)
        self.agents = agents
        self.agent = agents[0]
        self.offensive_lang_detector = OffensiveLanguageDetector()
        self.data = []
        self.exact_match = False
        self.num_images = opt['num_images']
        self.second_resp = opt.get('second_response', False)
        self.config = config_first if not self.second_resp else config_second
        if opt.get('yfcc_path'):
            self.image_path = opt['yfcc_path']
        else:
            self.image_path = os.path.join(opt['datapath'], 'yfcc_images')

    def episode_done(self):
        return self.chat_done

    @staticmethod
    def _highlight(text, color):
        """Return stripped ``text`` wrapped in bold, ``color``-ed HTML."""
        return '<b><span style="color:{}">{}\n</span></b>'.format(
            color, text.strip())

    def parley(self):
        """RESPONDER is given an image and a comment, and is told to give a
        response to the comment.

        One turn is played per image. In second-response mode the example's
        first response is shown too, and the personality is taken from the
        example's first dialog entry instead of the personality generator.
        Each response is screened for offensive language (at most three
        attempts). The loop stops early on timeout/disconnect.
        """
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'
        # We only have to worry about 1 agent
        agent = self.agents[0]
        # First, we give RESPONDER their personality instructions, and image
        while self.turn_idx < self.num_images:
            print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
            # Send personality + image + comment to turker
            self.example_num, example = (
                self.agent.example_generator.pop_example())
            control_msg['text'] = self.get_instruction(
                tag='start', agent_id=agent.id, turn_num=self.turn_idx + 1)

            if self.second_resp:
                control_msg['comment_text'] = self._highlight(
                    example['dialog'][0][1], 'red')
                control_msg['response_text'] = self._highlight(
                    example['dialog'][1][1], 'blue')
            else:
                control_msg['comment_text'] = self._highlight(
                    example['comment'], 'red')

            img = load_image(
                os.path.join(self.image_path,
                             '{}.jpg'.format(example['image_hash'])))
            buffered = BytesIO()
            img.save(buffered, format='JPEG')
            control_msg['image'] = base64.b64encode(
                buffered.getvalue()).decode('ascii')
            if self.second_resp:
                # -1: nothing was popped from the personality generator.
                self.pers_idx, personality = (-1, example['dialog'][0][0])
            else:
                self.pers_idx, personality = (
                    self.agent.personality_generator.pop_personality())
            control_msg['personality_text'] = self._highlight(
                personality, 'red' if self.second_resp else 'blue')
            control_msg['description'] = self.config['task_description']
            agent.observe(validate(control_msg))
            time.sleep(1)

            # Collect comment from turker (at most three attempts).
            offensive_counter = 0
            while offensive_counter < 3:
                self.acts[0] = agent.act(timeout=self.max_resp_time)
                if self.check_timeout(self.acts[0]):
                    break
                response = self.acts[0]['text']
                offensive = (self.offensive_lang_detector
                             .contains_offensive_language(response))
                if not offensive:
                    break
                offensive_counter += 1
                if offensive_counter < 3:
                    # Tell Turker to not be offensive! Skipped after the last
                    # attempt, since no revision will be collected.
                    agent.observe(validate({
                        'id': 'SYSTEM',
                        'text': OFFENSIVE_MSG,
                    }))

            if self.chat_done:
                break
            ex_to_save = example.copy()
            key = 'second_response' if self.second_resp else 'first_response'
            ex_to_save[key] = response
            ex_to_save['{}_personality'.format(key)] = personality
            ex_to_save['contains_offensive_language'] = offensive
            self.data.append(ex_to_save)
            self.turn_idx += 1

        if self.turn_idx == self.num_images:
            control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images)
            agent.observe(validate(control_msg))
        self.chat_done = True

    def get_instruction(self, agent_id=None, tag='first', turn_num=0):
        """Return the message text for ``tag`` ('start' or 'timeout').

        Any other tag yields None.
        """
        if tag == 'start':
            start_msg = START_MSG if not self.second_resp else START_MSG_SECOND_RESP
            return start_msg.format(turn_num)
        if tag == 'timeout':
            return TIMEOUT_MSG

    def check_timeout(self, act):
        """Return True (and end the chat) if ``act`` is a timeout or a
        disconnect; on timeout the other agents are notified."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        elif act['text'] == '[DISCONNECT]':
            self.chat_done = True
            return True
        else:
            return False

    def save_data(self):
        """Pickle the collected, non-offensive responses under
        ``opt['data_path']``.

        If any agent abandoned, returned, disconnected from or let the HIT
        expire, the most recently popped example (and, outside
        second-response mode, personality) is pushed back onto its
        generator's stack and the file name gets an ``_incomplete`` suffix.
        """
        convo_finished = not any(
            ag.hit_is_abandoned or ag.hit_is_returned or ag.disconnected
            or ag.hit_is_expired for ag in self.agents)
        if not convo_finished:
            ag = self.agents[-1]
            if not self.second_resp:
                ag.personality_generator.push_personality(self.pers_idx)
                # Only log a personality push when one actually happened.
                print('\n**Push personality {} back to stack. **\n'.format(
                    self.pers_idx))
            ag.example_generator.push_example(self.example_num)
            print('\n**Push example {} back to stack. **\n'.format(
                self.example_num))
        if not self.second_resp:
            self.agents[0].personality_generator.save_idx_stack()
        self.agents[0].example_generator.save_idx_stack()
        data_path = self.opt['data_path']
        # exist_ok avoids a race between the existence check and creation.
        os.makedirs(data_path, exist_ok=True)
        suffix = '' if convo_finished else '_incomplete'
        filename = os.path.join(
            data_path,
            '{}_{}_{}{}.pkl'.format(time.strftime('%Y%m%d-%H%M%S'),
                                    np.random.randint(0, 1000),
                                    self.task_type, suffix))
        key = 'second_response' if self.second_resp else 'first_response'
        responses = [d[key] for d in self.data]

        if len(responses) >= 2 and _exact_match(responses[0], responses[1:]):
            self.exact_match = True
        # Keep only responses the detector did not flag (this relies on the
        # detector returning None for clean text -- NOTE(review): confirm).
        data_to_save = [
            d for d in self.data if d['contains_offensive_language'] is None
        ]
        # Use a context manager so the file handle is always closed.
        with open(filename, 'wb') as f:
            pickle.dump(
                {
                    'data': data_to_save,
                    'worker': self.agents[0].worker_id,
                    'hit_id': self.agents[0].hit_id,
                    'assignment_id': self.agents[0].assignment_id,
                    'exact_match': self.exact_match
                }, f)
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def review_work(self):
        """Reject the HIT for offensive or duplicated responses.

        Offensive language takes precedence over identical responses.
        ``self.exact_match`` is set in ``save_data``, so this is presumably
        called after it.
        """
        contains_offense = any(d['contains_offensive_language']
                               for d in self.data)

        def review_agent(ag):
            if contains_offense:
                ag.reject_work(reason='We have rejected this HIT because at '
                               'least one of your comments '
                               'contains offensive language')
                print(
                    'Rejected work for agent {} for offensive language'.format(
                        ag.worker_id))
            elif self.exact_match:
                ag.reject_work(reason='We have rejected this HIT because '
                               'all of your comments are the exact same')
                print('Rejected work for agent {} for same comments'.format(
                    ag.worker_id))

        # A local closure is enough for the threading backend; a module-level
        # ``global`` could be overwritten by another world reviewing
        # concurrently in a different thread.
        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(review_agent)(agent)
                                      for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk agent
        is disconnected then it could prevent other mturk agents from
        completing.
        """

        def shutdown_agent(agent):
            agent.shutdown()

        # Threading backend: the local function does not need to be global.
        Parallel(n_jobs=len(self.agents),
                 backend='threading')(delayed(shutdown_agent)(agent)
                                      for agent in self.agents)