def detect(opt, printargs=None, print_parser=None):
    """Checks a task for offensive language.

    Runs the model/task specified in ``opt`` to completion, screening every
    act's text and labels with OffensiveLanguageDetector, and prints periodic
    progress (with an ETA) plus a final count of offensive examples.

    :param opt: parsed ParlAI options (task, model, display_examples, ...)
    :param printargs: deprecated alias for print_parser — not used here
    :param print_parser: if True and ``opt`` is a ParlaiParser, print the
        resolved arguments after the model loads; pass a parser to use it
        directly, or False/None to skip
    :return: the world's final report dict
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    # Fixed seed so repeated runs visit examples in the same order.
    random.seed(42)
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveLanguageDetector()
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        # Non-positive interval disables periodic logging entirely.
        log_every_n_secs = float('inf')
    log_time = Timer()
    tot_time = 0
    # Show some example dialogs:
    cnt = 0  # number of examples containing at least one offensive string
    while not world.epoch_done():
        world.parley()
        offensive = False
        for a in world.acts:
            if bad.contains_offensive_language(a.get('text', '')):
                offensive = True
            # Falls back to eval_labels; '' default makes the loop a no-op
            # when neither key is present.
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                if bad.contains_offensive_language(l):
                    offensive = True
        if offensive:
            if opt['display_examples']:
                print(world.display() + "\n~~")
            cnt += 1
        if log_time.time() > log_every_n_secs:
            tot_time += log_time.time()
            report = world.report()
            log = {'total': report['total']}
            # Fraction of the epoch completed; used to extrapolate the ETA.
            log['done'] = report['total'] / world.num_examples()
            if log['done'] > 0:
                log['eta'] = int(tot_time / log['done'] - tot_time)
            z = '%.2f' % (100 * log['done'])
            log['done'] = str(z) + '%'
            log['offenses'] = cnt
            print(str(int(tot_time)) + "s elapsed: " + str(log))
            log_time.reset()
    if world.epoch_done():
        print("EPOCH DONE")
    print(str(cnt) + " offensive messages found out of " +
          str(world.num_examples()) + " messages.")
    return world.report()
def detect(opt, printargs=None, print_parser=None):
    """Checks a task for offensive language.

    Newer variant of ``detect`` that collects the actual offending words and
    delegates progress/ETA printing to TimeLogger. NOTE(review): this
    definition shares its name with the variant above and shadows it when
    both are present in the same module — confirm which one is intended.

    :param opt: parsed ParlAI options (task, model, display_examples, ...)
    :param printargs: deprecated alias for print_parser — not used here
    :param print_parser: if True and ``opt`` is a ParlaiParser, print the
        resolved arguments after the model loads
    :return: the world's final report dict
    """
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    # Fixed seed so repeated runs visit examples in the same order.
    random.seed(42)
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    bad = OffensiveLanguageDetector()
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()
    log_every_n_secs = opt.get('log_every_n_secs', -1)
    if log_every_n_secs <= 0:
        log_every_n_secs = float('inf')
    log_time = TimeLogger()
    # Show some example dialogs:
    cnt = 0  # running count of offensive strings found (not examples)
    while not world.epoch_done():
        world.parley()
        words = []  # offending strings detected in this parley
        for a in world.acts:
            # contains_offensive_language presumably returns the offending
            # word (truthy) or None — TODO confirm against the detector.
            offensive = bad.contains_offensive_language(a.get('text', ''))
            if offensive:
                words.append(offensive)
            labels = a.get('labels', a.get('eval_labels', ''))
            for l in labels:
                offensive = bad.contains_offensive_language(l)
                if offensive:
                    words.append(offensive)
        if len(words) > 0 and opt['display_examples']:
            print(world.display())
            print("[Offensive words detected:]", ', '.join(words))
            print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
        cnt += len(words)
        if log_time.time() > log_every_n_secs:
            report = world.report()
            log = {'offenses': cnt}
            text, log = log_time.log(report['exs'], world.num_examples(), log)
            print(text)
    if world.epoch_done():
        print("EPOCH DONE")
    print(str(cnt) + " offensive messages found out of " +
          str(world.num_examples()) + " messages.")
    return world.report()
def __init__(
    self,
    opt,
    agents=None,
    shared=None,
    world_tag='NONE',
    ir_agent=None,
    task='',
    wiki_title_to_passage=None,
):
    """Set up a Wizard-of-Wikipedia MTurk dialogue world.

    :param opt: world options (min_turns, max_turns, is_sandbox,
        max_resp_time, num_passages_retrieved, ...)
    :param agents: the two MTurk agents (wizard + apprentice)
    :param shared: shared world state passed to the superclass
    :param world_tag: label used to identify this world in log output
    :param ir_agent: retrieval agent used to fetch wiki passages
    :param task: task name (unused here beyond the signature)
    :param wiki_title_to_passage: dict mapping wiki titles to passage text
    """
    self.turn_idx = 0
    self.min_turns = opt['min_turns']
    self.max_turns = opt['max_turns']
    # Random conversation length in [min_turns, max_turns] (inclusive).
    self.num_turns = np.random.randint(self.min_turns, self.max_turns) + 1
    self.dialog = []
    self.wizard_eval = 0
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
    self.chat_done = False
    self.world_tag = world_tag
    self.max_resp_time = opt['max_resp_time']  # in secs
    self.num_passages_to_retrieve = opt['num_passages_retrieved']
    super().__init__(opt, agents, shared)
    # Coin-flip ordering decides which worker plays which role.
    self.agents = sorted(agents, key=lambda x: x.id,
                         reverse=random.random() <= 0.5)
    # Personas and retriever
    self.persona_generator = self.agents[0].persona_generator
    self.relevant_topics = []
    # Keep popping personas until one yields at least one topic.
    while not self.relevant_topics:
        self.persona_to_topics = {}
        self.persona_idx, persona_data = self.persona_generator.pop_persona()
        for p in persona_data:
            # Strip a single leading space from the persona string.
            if p[0] == ' ':
                p = p[1:]
            if p not in self.persona_to_topics:
                self.persona_to_topics[p] = []
            topics = set(self.persona_generator.get_topics(p))
            for t in topics:
                self.relevant_topics.append(t + ' ({})'.format(p))
                self.persona_to_topics[p].append(t)
    self.ir_agent = ir_agent
    self.setup_tokenizer(opt)
    self.chosen_topic = ''
    self.chosen_topic_passage = {}
    self.OLD = OffensiveLanguageDetector()
    # Load the title to passage dictionary
    self.wiki_title_to_passage = wiki_title_to_passage
def __init__(self, opt):
    """Load task examples and the persisted index stack for this run.

    Builds (if needed) and reads the personality_captions / image_chat JSON
    splits, then restores the pickled stack of example indices so examples
    already served in previous runs are not repeated.

    :param opt: options dict; reads 'second_response', 'is_sandbox',
        and 'datapath'
    """
    self.second_resp = opt.get('second_response')
    # Pickle file in the CWD tracking which example indices remain.
    self.examples_idx_stack_path = os.path.join(
        os.getcwd(),
        './{}_examples_stack{}.pkl'.format(
            'second_response' if self.second_resp else 'first_response',
            '_sandbox' if opt['is_sandbox'] else ''))
    self.OLD = OffensiveLanguageDetector()
    self.opt = opt
    # Ensure both datasets are downloaded/built before reading them.
    build_pc(opt)
    build_ic(opt)
    df = 'personality_captions' if not self.second_resp else 'image_chat'
    data_path = os.path.join(self.opt['datapath'], '{}/{}.json')
    self.data = []
    for dt in ['train', 'val', 'test']:
        # image_chat names its validation split 'valid', not 'val'.
        if self.second_resp and dt == 'val':
            dt = 'valid'
        with open(data_path.format(df, dt)) as f:
            self.data += json.load(f)
    if self.second_resp:
        # Second responses require an existing comment + reply pair.
        self.data = [d for d in self.data if len(d['dialog']) > 1]
    if os.path.exists(self.examples_idx_stack_path):
        with open(self.examples_idx_stack_path, 'rb') as handle:
            self.idx_stack = pickle.load(handle)
    else:
        # First run: build a fresh stack and persist it immediately.
        self.idx_stack = []
        self.add_idx_stack()
        self.save_idx_stack()
def __init__(self, opt, agents=None, shared=None):
    """Initialize world state, the image pool, and the embeddable links.

    :param opt: options dict; reads 'is_sandbox' and 'participants'
    :param agents: participating MTurk agents
    :param shared: shared world state (unused here)
    """
    # Add passed in agents directly.
    if opt['is_sandbox']:
        self.task_type = 'sandbox'
    else:
        self.task_type = 'live'
    self.agents = agents
    self.acts = [None] * len(agents)
    self.episodeDone = False
    self.opt = opt
    self.data = []
    self.offensive_lang_detector = OffensiveLanguageDetector()
    # Pick one participant index at random up front.
    self.rand_index = random.randint(0, self.opt["participants"] - 1)
    # Local image files (or S3-mirrored locations) shown during the task.
    image_dir = "/projects2/ParlAI/data/yfcc_images/"
    self.imgs = [
        image_dir + fname
        for fname in (
            "1e22a9cf867d718551386b427c3b6d18.jpg",
            "96472caea58db27769f1c282e2ac0.jpg",
            "f09d8fb76822158de129acb0fef463.jpg",
            "6e4ccc739ff44ed11da20ad9892317.jpg",
            "e7e1844aa9e67cddc6ffe8804d76e45b.jpg",
            "5547b3852afec328a491a696ace99a.jpg",
            "b326345ae2b2bd14ebf74aaa31e571a.jpg",
            "75a13ebe4be7ab5b3f68f692d7db081.jpg",
            "246eea26a3fc2d886be795790a7495.jpg",
            "010722aa6d2327deddb4ead5e089ea.jpg",
        )
    ]
    # YouTube links, rewritten from watch URLs to embeddable URLs.
    self.links = [
        url.replace("watch?v=", "embed/")
        for url in (
            "https://www.youtube.com/watch?v=7gUv0xcFqMk",
            "https://www.youtube.com/watch?v=6vYJyOGKCHE",
            "https://www.youtube.com/watch?v=3SJ0Rd7XU4Y",
        )
    ]
def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
    """Set up the personality-captions collection world for one worker.

    :param opt: options dict (is_sandbox, max_resp_time, num_images,
        task_type, optional multiple_personality)
    :param agents: MTurk agents; the first is the commenting worker
    :param shared: shared world state passed to the superclass
    :param world_tag: label used to identify this world in log output
    """
    # Turn bookkeeping and world identification.
    self.turn_idx = 0
    if opt['is_sandbox']:
        self.task_type = 'sandbox'
    else:
        self.task_type = 'live'
    self.chat_done = False
    self.world_tag = world_tag
    self.max_resp_time = opt['max_resp_time']  # in secs
    super().__init__(opt, agents, shared)
    # Agent handles and offensive-language screening.
    self.agents = agents
    self.offensive_lang_detector = OffensiveLanguageDetector()
    self.agent = agents[0]
    # Per-HIT collection state.
    self.data = []
    self.exact_match = False
    self.num_images = opt['num_images']
    self.multiple_personality = opt.get('multiple_personality', False)
    self.eval = 0
    # Task-type driven configuration.
    task_key = opt['task_type']
    self.data_type = task_key
    self.task_type_title = TASK_TYPE_TO_TITLE[task_key]
    self.config = TASK_TYPE_TO_CONFIG[task_key]
def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
    """Set up the image-chat response-collection world for one worker.

    :param opt: options dict (is_sandbox, max_resp_time, num_images,
        optional second_response and yfcc_path, datapath)
    :param agents: MTurk agents; the first is the responding worker
    :param shared: shared world state passed to the superclass
    :param world_tag: label used to identify this world in log output
    """
    self.turn_idx = 0
    self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
    self.chat_done = False
    self.world_tag = world_tag
    self.max_resp_time = opt['max_resp_time']  # in secs
    super().__init__(opt, agents, shared)
    self.agents = agents
    self.agent = agents[0]
    self.offensive_lang_detector = OffensiveLanguageDetector()
    self.data = []
    self.exact_match = False
    self.num_images = opt['num_images']
    # Second-response mode collects a reply to an existing exchange.
    self.second_resp = opt.get('second_response', False)
    self.config = config_second if self.second_resp else config_first
    # Prefer an explicitly supplied YFCC image directory over the default.
    yfcc = opt.get('yfcc_path')
    if yfcc:
        self.image_path = yfcc
    else:
        self.image_path = os.path.join(opt['datapath'], 'yfcc_images')
class MTurkWizardOfWikipediaWorld(MultiAgentDialogWorld):
    """World where two agents have a dialogue; one chats freely, perhaps
    based on a persona, while the other is the 'wizard', who bases his/her
    responses on documents (i.e. sentences) retrieved based on what the
    other agent says.
    """

    def __init__(
        self,
        opt,
        agents=None,
        shared=None,
        world_tag='NONE',
        ir_agent=None,
        task='',
        wiki_title_to_passage=None,
    ):
        """Set up the world: turn limits, role assignment, persona/topic
        selection, tokenizer, and the offensive-language detector."""
        self.turn_idx = 0
        self.min_turns = opt['min_turns']
        self.max_turns = opt['max_turns']
        # Random conversation length in [min_turns, max_turns] (inclusive).
        self.num_turns = np.random.randint(self.min_turns,
                                           self.max_turns) + 1
        self.dialog = []
        self.wizard_eval = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        self.num_passages_to_retrieve = opt['num_passages_retrieved']
        super().__init__(opt, agents, shared)
        # Coin-flip ordering decides which worker plays which role.
        self.agents = sorted(agents, key=lambda x: x.id,
                             reverse=random.random() <= 0.5)
        # Personas and retriever
        self.persona_generator = self.agents[0].persona_generator
        self.relevant_topics = []
        # Keep popping personas until one yields at least one topic.
        while not self.relevant_topics:
            self.persona_to_topics = {}
            self.persona_idx, persona_data = \
                self.persona_generator.pop_persona()
            for p in persona_data:
                # Strip a single leading space from the persona string.
                if p[0] == ' ':
                    p = p[1:]
                if p not in self.persona_to_topics:
                    self.persona_to_topics[p] = []
                topics = set(self.persona_generator.get_topics(p))
                for t in topics:
                    self.relevant_topics.append(t + ' ({})'.format(p))
                    self.persona_to_topics[p].append(t)
        self.ir_agent = ir_agent
        self.setup_tokenizer(opt)
        self.chosen_topic = ''
        self.chosen_topic_passage = {}
        self.OLD = OffensiveLanguageDetector()
        # Load the title to passage dictionary
        self.wiki_title_to_passage = wiki_title_to_passage

    def episode_done(self):
        """Return True once the chat has been marked finished."""
        return self.chat_done

    def setup_tokenizer(self, opt):
        """Load the nltk punkt sentence tokenizer for the configured
        language, downloading the model on first use."""
        try:
            import nltk
        except ImportError:
            raise ImportError('Please install nltk (e.g. pip install nltk).')
        # nltk-specific setup
        st_path = 'tokenizers/punkt/{0}.pickle'.format(opt['dict_language'])
        try:
            self.sent_tok = nltk.data.load(st_path)
        except LookupError:
            nltk.download('punkt')
            self.sent_tok = nltk.data.load(st_path)

    def sufficient_overlap(self, text, sent_dict):
        """Return True if ``text`` shares at least word_overlap_threshold
        non-stopword prefixes (first 4 chars, a crude stem) with any checked
        sentence in ``sent_dict``."""
        text_list = [
            w[:4] for w in split_tokenize(text.lower())
            if w not in STOPWORDS
        ]
        for _, sentence in sent_dict.items():
            sentence_list = [
                w[:4] for w in split_tokenize(sentence.lower())
                if w not in STOPWORDS
            ]
            if len(set(text_list).intersection(
                    set(sentence_list))) >= self.opt.get(
                        'word_overlap_threshold', 2):
                return True
        return False

    def parley(self):
        """Each agent acts; when the APPRENTICE says something, the WIZARD
        is given retrieved documents based on the text response"""
        self.turn_idx += 1
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'
        print(self.world_tag + ' is at turn {}...'.format(self.turn_idx))
        '''First Turn: We give the first agent the list of topics to
        choose from '''
        if self.turn_idx == 1:
            for idx, agent in enumerate(self.agents):
                '''If we are giving the persona, do that :)'''
                control_msg['text'] = self.get_instruction(
                    tag='start', agent_id=agent.id)
                if agent.id == WIZARD:
                    control_msg['description'] = config['wizard_onboarding']
                else:
                    control_msg['description'] = config[
                        'apprentice_onboarding']
                agent.observe(validate(control_msg))
                if idx == 0:
                    time.sleep(3)
            '''Send First Person the list of relevant topics'''
            self.agents[0].observe(validate({
                'id': 'SYSTEM',
                'text': PICK_TOPIC_MSG,
                'relevant_topics': self.relevant_topics,
            }))
            topic_act = self.agents[0].act(timeout=self.max_resp_time)
            timed_out = self.check_timeout(topic_act)
            if not timed_out:
                if self.agents[0].id == APPRENTICE:
                    pick_msg = AFTER_PICK_TOPIC_MSG
                else:
                    pick_msg = AFTER_PICK_TOPIC_WIZARD_MSG
                self.agents[0].observe({'id': 'SYSTEM', 'text': pick_msg})
                self.chosen_topic = topic_act['text']
            '''Now, send the wiki page for the chosen topic to the wizard'''
            for idx, agent in enumerate(self.agents):
                if agent.id == WIZARD:
                    passage = self.wiki_title_to_passage.get(
                        self.chosen_topic, '')
                    if passage == '':
                        break
                    # First line of the passage is the article title.
                    split = passage.split('\n')
                    title = split[0]
                    split = self.sent_tok.tokenize(" ".join(split[1:]))
                    # Drop the leading character of the first sentence
                    # (presumably a stray space — TODO confirm).
                    split[0] = split[0][1:]
                    sentences = []
                    for sent in split:
                        if len(sent) > 1:
                            sentences.append(sent)
                        if len(" ".join(sentences)) > MAX_DOC_LEN * 2:
                            break
                    msg_text = (AFTER_PARTNER_PICK_TOPIC_WIZARD_MSG
                                if idx == 1 else "")
                    control_msg['text'] = msg_text
                    control_msg['chosen_topic_passages'] = [[title,
                                                             sentences]]
                    agent.observe(validate(control_msg))
                    self.chosen_topic_passage = {
                        'topic': self.chosen_topic,
                        'full_passage': passage,
                        'shown_passage': sentences,
                    }
        '''If we get to the min turns, inform turker that they can end if
        they want '''
        if self.turn_idx == self.num_turns + 1:
            for agent in self.agents:
                control_msg['text'] = self.get_instruction(
                    tag='exceed_min_turns')
                control_msg['exceed_min_turns'] = True
                agent.observe(validate(control_msg))
        '''Otherwise, we proceed accordingly'''
        acts = self.acts
        for idx, agent in enumerate(self.agents):
            # Increase response time for wizard
            max_response_time = self.max_resp_time * (
                1 if agent.id == APPRENTICE else 1.5)
            acts[idx] = agent.act(timeout=max_response_time)
            self.check_timeout(acts[idx])
            # If chat ends
            if acts[idx]['episode_done']:
                self.chat_done = True
                for ag in self.agents:
                    if ag != agent and ag.some_agent_disconnected:
                        control_msg['text'] = UNEXPECTED_DISCONNECTION_MSG
                        ag.observe(validate(control_msg))
                return
            if self.turn_idx > self.num_turns:
                for ag in self.agents:
                    ag.observe(validate(acts[idx]))
                    '''Have Apprentice Agent Eval Wizard Agent'''
                    if ag.id == APPRENTICE:
                        control_msg['text'] = EVAL_WIZARD_MSG
                        control_msg['wizard_eval'] = True
                        ag.observe(validate(control_msg))
                        act = ag.act(timeout=self.max_resp_time)
                        self.check_timeout(act)
                        try:
                            # Clamp the rating into [1, 5].
                            w_ev = int(act['text'])
                            w_ev = max(w_ev, 1) if w_ev <= 0 else min(
                                w_ev, 5)
                        except ValueError:
                            # If there is a disconnect here
                            w_ev = -1
                        self.wizard_eval = w_ev
                    control_msg['text'] = CHAT_ENDED_MSG
                    ag.observe(validate(control_msg))
                return
            '''Set up msg info dict to save in dialog'''
            msg_info = {
                'speaker': '{}_{}'.format(idx, agent.id),
                'text': acts[idx]['text'],
                'turn': self.turn_idx,
                'time': time.time(),
                'offensive': self.OLD.contains_offensive_language(
                    acts[idx]['text']),
            }
            '''Get clicked passages and checked sentences from Wizard'''
            if 'clicked_passages' in acts[idx]:
                msg_info['clicked_passages'] = acts[idx]['clicked_passages']
                checked_sents = {}
                for k, v in acts[idx]['checked_sentences'].items():
                    if k == 'no_passages_used':
                        checked_sents[k] = v
                    else:
                        # Keys look like '<person>_<topic_idx>_<sent_idx>'.
                        split = k.split('_')
                        person = split[0]
                        topic_idx = split[1]
                        sent_idx = split[2]
                        if person == 'partner':
                            # Passage retrieved for the partner's last turn.
                            sub_passages = [
                                p.split('\n')[0]
                                for p in self.dialog[-1]['full_passages']
                            ]
                            topic = sub_passages[int(topic_idx)]
                        elif person == 'self':
                            # Passage retrieved for the wizard's own
                            # previous turn.
                            sub_passages = [
                                p.split('\n')[0]
                                for p in self.dialog[-2]['full_passages']
                            ]
                            topic = sub_passages[int(topic_idx)]
                        else:
                            topic = self.chosen_topic
                        cs_key = '_'.join(
                            [person, '_'.join(topic.split(' ')), sent_idx])
                        checked_sents[cs_key] = v
                msg_info['checked_sentence'] = checked_sents
                msg_info['checked_passage'] = acts[idx]['checked_passages']
                # 'good_message' marks sufficient grounding in the checked
                # sentences; used later by check_wizard_quality().
                msg_info['good_message'] = self.sufficient_overlap(
                    msg_info['text'], msg_info['checked_sentence'])
            '''Retrieve Passages'''
            ir_passages = self.retrieve_passages(copy.deepcopy(acts[idx]))
            passages = self.format_passages(ir_passages)
            msg_info['full_passages'] = ir_passages
            msg_info['shown_passages'] = passages
            if agent.id == WIZARD:
                '''Give Wizard the Relevant Passages'''
                control_msg['text'] = ''
                control_msg['self_retrieved_passages'] = passages
                agent.observe(validate(control_msg))
            self.dialog.append(msg_info)
            for other_agent in self.agents:
                if other_agent != agent:
                    other_agent.observe(validate(acts[idx]))
                    if other_agent.id == WIZARD:
                        control_msg[
                            'text'] = PARTNER_RETRIEVED_PASSAGES_INST_MSG
                        control_msg['partner_retrieved_passages'] = passages
                        other_agent.observe(validate(control_msg))

    def format_passages(self, ir_passages, max_length=MAX_DOC_LEN):
        """Split each retrieved passage into a [title, sentences] pair,
        truncating sentences once their joined length exceeds max_length."""
        passages = []
        if len(ir_passages) == 1:
            # Didn't receive any passages
            passages.append(['No Passages Retrieved', []])
        else:
            for passage in ir_passages:
                split = passage.split('\n')
                title = split[0]
                split = self.sent_tok.tokenize(" ".join(split[1:]))
                split[0] = split[0][1:]
                sentences = []
                for sent in split:
                    if len(sent) > 1:
                        sentences.append(sent)
                    if len(" ".join(sentences)) > max_length:
                        break
                passages.append([title, sentences])
        return passages

    def retrieve_passages(self, act, num_passages=None):
        """Send ``act`` to the IR agent and return up to ``num_passages``
        candidate passages (defaults to the world's configured count)."""
        if not num_passages:
            num_passages = self.num_passages_to_retrieve
        self.ir_agent.observe(act)
        action = self.ir_agent.act()
        passages = action.get('text_candidates', [action.get('text', "")])
        return passages[:min(len(passages), num_passages)]

    def get_instruction(self, agent_id=None, tag='first'):
        """Return the system instruction text for the given tag/role."""
        if tag == 'start':
            start_msg = (WIZARD_START_MSG if agent_id == WIZARD
                         else APPRENTICE_START_MSG)
            return start_msg.format(self.num_turns)
        if tag == 'timeout':
            return TIMEOUT_MSG
        if tag == 'exceed_min_turns':
            return EXCEED_MIN_TURNS_MSG.format(self.num_turns)

    def check_timeout(self, act):
        """If ``act`` is a timeout, notify the other agent, end the chat,
        and return True; otherwise return False."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        else:
            return False

    def reset_random(self):
        """Re-roll the target conversation length for a new episode."""
        self.num_turns = np.random.randint(self.min_turns,
                                           self.max_turns) + 1

    def save_data(self):
        """Persist the conversation to a pickle; push the persona back on
        the stack if the conversation did not finish cleanly."""
        # save persona_idx_stack
        convo_finished = self.turn_idx >= self.num_turns + 1
        for ag in self.agents:
            if (ag.hit_is_abandoned or ag.hit_is_returned or
                    ag.disconnected or ag.hit_is_expired):
                convo_finished = False
        if not convo_finished:
            self.persona_generator.push_persona(self.persona_idx)
            print("\n**Push persona {} back to stack. **\n".format(
                self.persona_idx))
        self.agents[0].persona_generator.save_idx_stack()
        data_path = self.opt['data_path']
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        self.convo_finished = convo_finished
        self.wizard_worker = ''
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.pkl'.format(
                    time.strftime("%Y%m%d-%H%M%S"),
                    np.random.randint(0, 1000),
                    self.task_type,
                ),
            )
            # Quality check only makes sense for completed conversations.
            self.good_wiz, self.wizard_worker = self.check_wizard_quality()
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.pkl'.format(
                    time.strftime("%Y%m%d-%H%M%S"),
                    np.random.randint(0, 1000),
                    self.task_type,
                ),
            )
            self.good_wiz = True
        pickle.dump(
            {
                'persona': self.persona_to_topics,
                'relevant_topics': self.relevant_topics,
                'chosen_topic_passage': self.chosen_topic_passage,
                'dialog': self.dialog,
                'speaker_with_persona': self.agents[0].worker_id,
                'workers': [ag.worker_id for ag in self.agents],
                'n_turn': self.num_turns,
                'hit_ids': [ag.hit_id for ag in self.agents],
                'assignment_ids': [ag.assignment_id for ag in self.agents],
                'wizard_eval': self.wizard_eval,
                'chosen_topic': self.chosen_topic,
                'wizard_good': convo_finished and self.good_wiz,
                'good_wizard_worker':
                    self.wizard_worker if self.good_wiz else '',
                'bad_wizard_worker':
                    self.wizard_worker if not self.good_wiz else '',
            },
            open(filename, 'wb'),
        )
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def check_wizard_quality(self):
        '''Determines whether to soft-block this turker or not
        Only called if the conversation finishes
        Returns True if the Wizard is good
        '''
        num_good_sents = len(
            list(
                filter(
                    lambda info: 'good_message' in info and
                    info['good_message'],
                    self.dialog,
                )))
        wizard_worker = [w for w in self.agents
                         if w.id == WIZARD][0].worker_id
        data_path = self.opt['current_working_dir']
        bad_wizards = os.path.join(data_path, 'bad_wizards.txt')
        good_wizards = os.path.join(data_path, 'good_wizards.txt')
        if num_good_sents < self.opt['num_good_sentence_threshold']:
            # Only record worker ids when running live, not in sandbox.
            if not self.opt['is_sandbox']:
                with open(bad_wizards, 'a') as f:
                    f.write(wizard_worker + '\n')
            return False, wizard_worker
        else:
            if not self.opt['is_sandbox']:
                with open(good_wizards, 'a') as f:
                    f.write(wizard_worker + '\n')
            return True, wizard_worker

    def review_work(self):
        """Reject any agent whose messages contained offensive language."""
        global review_agent

        def review_agent(ag):
            # Reject work only for the role that produced the offense.
            role = ag.id
            for d in self.dialog:
                if role in d['speaker']:
                    if d['offensive']:
                        ag.reject_work(
                            reason='Your HIT has been rejected '
                            'because we detected offensive '
                            'language in your submission.')

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(review_agent)(agent) for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk
        agent is disconnected then it could prevent other mturk agents
        from completing.
        """
        global shutdown_agent

        def shutdown_agent(agent):
            agent.shutdown()

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(shutdown_agent)(agent) for agent in self.agents)
class MTurkPersonalityCaptionsWorld(MultiAgentDialogWorld):
    """World an agent observes ten images, with ten different
    personalities, and writes engaging comments about them
    """

    def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
        """Set up the collection world for a single commenting worker."""
        self.turn_idx = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        super().__init__(opt, agents, shared)
        self.agents = agents
        self.offensive_lang_detector = OffensiveLanguageDetector()
        self.agent = agents[0]
        self.data = []
        self.exact_match = False
        self.num_images = opt['num_images']
        self.multiple_personality = opt.get('multiple_personality', False)
        self.eval = 0
        self.data_type = opt['task_type']
        self.task_type_title = TASK_TYPE_TO_TITLE[opt['task_type']]
        self.config = TASK_TYPE_TO_CONFIG[opt['task_type']]

    def episode_done(self):
        """Return True once the worker has finished or left."""
        return self.chat_done

    def parley(self):
        """ COMMENTER is given an image, and is told to give a comment for
        the image
        """
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'
        # First, we give COMMENTER their personality instructions, and image
        while self.turn_idx < self.num_images:
            print(self.world_tag +
                  ' is at turn {}...'.format(self.turn_idx))
            # Send personality + image to turker
            if not self.multiple_personality:
                pers_tup = \
                    self.agent.personality_generator.pop_personality()
                self.pers_idx, personality = pers_tup
                img_tup = self.agent.image_generator.pop_image()
                self.image_hash, image_path = img_tup
                img = load_image(image_path)
            else:
                # Paired mode: personality and image popped together.
                pair_tup = \
                    self.agent.personality_and_image_generator.pop_pair()
                self.pair_idx, personality, self.image_hash, image_path = \
                    pair_tup
                img = load_image(image_path)
            # Send the image inline as base64-encoded JPEG.
            buffered = BytesIO()
            img.save(buffered, format='JPEG')
            encoded = str(
                base64.b64encode(buffered.getvalue()).decode('ascii'))
            control_msg['image'] = encoded
            if self.data_type == 'personality':
                personality_text = '<b><span style="color:blue">' \
                    '{}\n</span></b>'.format(personality.strip())
                control_msg['personality_text'] = personality_text
            control_msg['text'] = self.get_instruction(
                tag='start',
                agent_id=self.agent.id,
                turn_num=self.turn_idx + 1)
            control_msg['description'] = self.config['task_description']
            control_msg['task_type_title'] = self.task_type_title
            self.agent.observe(validate(control_msg))
            time.sleep(1)
            # Collect comment from turker
            offensive_counter = 0
            # Allow up to 3 attempts before accepting an offensive comment.
            while offensive_counter < 3:
                idx = 0
                acts = self.acts
                acts[idx] = self.agent.act(timeout=self.max_resp_time)
                agent_left = self.check_timeout(acts[idx])
                if agent_left:
                    break
                comment = acts[idx]['text']
                offensive = \
                    self.offensive_lang_detector.contains_offensive_language(
                        comment)
                if offensive:
                    # Tell Turker to not be offensive!
                    offensive_msg = {
                        'id': 'SYSTEM',
                        'text': OFFENSIVE_MSG,
                    }
                    self.agent.observe(validate(offensive_msg))
                    offensive_counter += 1
                else:
                    break
            if self.chat_done:
                break
            self.data.append({
                'comment': comment,
                'personality': personality,
                'image_hash': self.image_hash,
                'image_path': image_path,
                'contains_offensive_language': offensive,
            })
            self.turn_idx += 1
            if self.turn_idx == self.num_images:
                # Final turn: collect the worker's evaluation of the task.
                control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images)
                control_msg['eval'] = True
                self.agent.observe(validate(control_msg))
                act = self.agent.act(timeout=self.max_resp_time)
                self.check_timeout(act)
                try:
                    ev_val = int(act['eval'])
                    ev_text = act['text']
                except BaseException:
                    # If there is a disconnect here
                    ev_val, ev_text = (-1, 'Eval not received')
                self.eval = {ev_val: ev_text}
                self.chat_done = True
        return

    def get_instruction(self, agent_id=None, tag='first', turn_num=0):
        """Return the system instruction text for the given tag."""
        if tag == 'start':
            return START_MSGS[self.data_type].format(turn_num)
        if tag == 'timeout':
            return TIMEOUT_MSG

    def check_timeout(self, act):
        """Return True (and end the chat) on a timeout or disconnect act."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        elif act['text'] == '[DISCONNECT]':
            self.chat_done = True
            return True
        else:
            return False

    def save_data(self):
        """Persist collected comments; push unused personality/image state
        back on the generators if the HIT did not finish."""
        convo_finished = True
        for ag in self.agents:
            if (ag.hit_is_abandoned or ag.hit_is_returned or
                    ag.disconnected or ag.hit_is_expired):
                convo_finished = False
        if not convo_finished:
            # NOTE(review): `ag` here is the loop variable leaked from the
            # loop above (i.e. the last agent) — confirm the single-agent
            # assumption holds for this world.
            if not self.multiple_personality:
                ag.personality_generator.push_personality(self.pers_idx)
                ag.image_generator.push_image(self.image_hash)
                print('\n**Push personality {} back to stack. **\n'.format(
                    self.pers_idx))
                print('\n**Push image {} back to stack. **\n'.format(
                    self.image_hash))
            else:
                ag.personality_and_image_generator.push_pair(self.pair_idx)
                print('\n**Push pair {} back to stack. **\n'.format(
                    self.pair_idx))
        self.agents[0].personality_generator.save_idx_stack()
        self.agents[0].image_generator.save_idx_stack()
        self.agents[0].personality_and_image_generator.save_idx_stack()
        data_path = self.opt['data_path']
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.pkl'.format(time.strftime('%Y%m%d-%H%M%S'),
                                      np.random.randint(0, 1000),
                                      self.task_type))
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.pkl'.format(
                    time.strftime('%Y%m%d-%H%M%S'),
                    np.random.randint(0, 1000), self.task_type))
        # Flag workers who pasted the identical comment for every image.
        comments = [d['comment'] for d in self.data]
        if len(comments) >= 2:
            c = comments[0]
            if _exact_match(c, comments[1:]):
                self.exact_match = True
        pickle.dump(
            {
                'data': self.data,
                'worker': self.agents[0].worker_id,
                'hit_id': self.agents[0].hit_id,
                'assignment_id': self.agents[0].assignment_id,
                'exact_match': self.exact_match,
                'task_eval': self.eval
            }, open(filename, 'wb'))
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def review_work(self):
        """Reject the HIT on offensive language or all-identical comments."""
        global review_agent

        def review_agent(ag):
            contains_offense = any(d['contains_offensive_language']
                                   for d in self.data)
            if contains_offense:
                ag.reject_work(reason='We have rejected this HIT because at '
                               'least one of your comments contains '
                               'offensive language')
                print('Rejected work for agent {} for '
                      'offensive language'.format(ag.worker_id))
            elif self.exact_match:
                ag.reject_work(reason='We have rejected this HIT because all '
                               'of your comments are the exact same')
                print('Rejected work for agent {} for '
                      'same comments'.format(ag.worker_id))

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(review_agent)(agent) for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk
        agent is disconnected then it could prevent other mturk agents
        from completing.
        """
        global shutdown_agent

        def shutdown_agent(agent):
            agent.shutdown()

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(shutdown_agent)(agent) for agent in self.agents)
class MTurkImageChatWorld(MultiAgentDialogWorld):
    """World where an agent observes 5 images and 5 comments, with 5
    different personalities, and writes engaging responses to the comments
    """

    def __init__(self, opt, agents=None, shared=None, world_tag='NONE'):
        """Set up the collection world for a single responding worker."""
        self.turn_idx = 0
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.world_tag = world_tag
        self.max_resp_time = opt['max_resp_time']  # in secs
        super().__init__(opt, agents, shared)
        self.agents = agents
        self.agent = agents[0]
        self.offensive_lang_detector = OffensiveLanguageDetector()
        self.data = []
        self.exact_match = False
        self.num_images = opt['num_images']
        # Second-response mode collects a reply to an existing exchange.
        self.second_resp = opt.get('second_response', False)
        self.config = config_first if not self.second_resp \
            else config_second
        if opt.get('yfcc_path'):
            self.image_path = opt['yfcc_path']
        else:
            self.image_path = os.path.join(opt['datapath'], 'yfcc_images')

    def episode_done(self):
        """Return True once the worker has finished or left."""
        return self.chat_done

    def parley(self):
        """RESPONDER is given an image and a comment, and is told to give
        a response for to the comment"""
        # Initial Message Value
        control_msg = {'episode_done': False}
        control_msg['id'] = 'SYSTEM'
        '''We only have to worry about 1 agent'''
        agent = self.agents[0]
        '''First, we give RESPONDER their personality instructions, and
        image '''
        while self.turn_idx < self.num_images:
            print(self.world_tag +
                  ' is at turn {}...'.format(self.turn_idx))
            # Send personality + image + comment to turker
            self.example_num, example = \
                self.agent.example_generator.pop_example()
            control_msg['text'] = self.get_instruction(
                tag='start', agent_id=agent.id,
                turn_num=self.turn_idx + 1)
            if self.second_resp:
                # Show both the original comment and the first response.
                control_msg['comment_text'] = \
                    '<b><span style="color:red">' \
                    '{}\n</span></b>'.format(
                        example['dialog'][0][1].strip())
                control_msg['response_text'] = \
                    '<b><span style="color:blue">' \
                    '{}\n</span></b>'.format(
                        example['dialog'][1][1].strip())
            else:
                control_msg['comment_text'] = \
                    '<b><span style="color:red">' \
                    '{}\n</span></b>'.format(example['comment'].strip())
            # Send the image inline as base64-encoded JPEG.
            img = load_image(
                os.path.join(self.image_path,
                             '{}.jpg'.format(example['image_hash'])))
            buffered = BytesIO()
            img.save(buffered, format='JPEG')
            encoded = str(
                base64.b64encode(buffered.getvalue()).decode('ascii'))
            control_msg['image'] = encoded
            if self.second_resp:
                # Reuse the personality from the original exchange.
                self.pers_idx, personality = (-1, example['dialog'][0][0])
            else:
                pers_tup = \
                    self.agent.personality_generator.pop_personality()
                self.pers_idx, personality = pers_tup
            personality_text = '<b><span style="color:{}">' \
                '{}\n</span></b>'.format(
                    'blue' if not self.second_resp else 'red',
                    personality.strip())
            control_msg['personality_text'] = personality_text
            control_msg['description'] = self.config['task_description']
            agent.observe(validate(control_msg))
            time.sleep(1)
            # Collect comment from turker
            offensive_counter = 0
            # Allow up to 3 attempts before giving up on a clean response.
            while offensive_counter < 3:
                idx = 0
                acts = self.acts
                acts[idx] = agent.act(timeout=self.max_resp_time)
                agent_left = self.check_timeout(acts[idx])
                if agent_left:
                    break
                response = acts[idx]['text']
                offensive = \
                    self.offensive_lang_detector.contains_offensive_language(
                        response)
                if offensive:
                    # Tell Turker to not be offensive!
                    offensive_counter += 1
                    if offensive_counter == 3:
                        break
                    offensive_msg = {
                        'id': 'SYSTEM',
                        'text': OFFENSIVE_MSG,
                    }
                    agent.observe(validate(offensive_msg))
                else:
                    break
            if self.chat_done:
                break
            ex_to_save = example.copy()
            key = 'second_response' if self.second_resp \
                else 'first_response'
            ex_to_save[key] = response
            ex_to_save['{}_personality'.format(key)] = personality
            ex_to_save['contains_offensive_language'] = offensive
            self.data.append(ex_to_save)
            self.turn_idx += 1
            if self.turn_idx == self.num_images:
                control_msg['text'] = CHAT_ENDED_MSG.format(self.num_images)
                agent.observe(validate(control_msg))
                self.chat_done = True
        return

    def get_instruction(self, agent_id=None, tag='first', turn_num=0):
        """Return the system instruction text for the given tag."""
        if tag == 'start':
            start_msg = START_MSG if not self.second_resp \
                else START_MSG_SECOND_RESP
            return start_msg.format(turn_num)
        if tag == 'timeout':
            return TIMEOUT_MSG

    def check_timeout(self, act):
        """Return True (and end the chat) on a timeout or disconnect act."""
        if act['text'] == '[TIMEOUT]' and act['episode_done']:
            control_msg = {'episode_done': True}
            control_msg['id'] = 'SYSTEM'
            control_msg['text'] = self.get_instruction(tag='timeout')
            for ag in self.agents:
                if ag.id != act['id']:
                    ag.observe(validate(control_msg))
            self.chat_done = True
            return True
        elif act['text'] == '[DISCONNECT]':
            self.chat_done = True
            return True
        else:
            return False

    def save_data(self):
        """Persist collected responses; push unused personality/example
        state back on the generators if the HIT did not finish."""
        convo_finished = True
        for ag in self.agents:
            if (ag.hit_is_abandoned or ag.hit_is_returned or
                    ag.disconnected or ag.hit_is_expired):
                convo_finished = False
        if not convo_finished:
            # NOTE(review): `ag` here is the loop variable leaked from the
            # loop above (i.e. the last agent) — confirm the single-agent
            # assumption holds for this world.
            if not self.second_resp:
                ag.personality_generator.push_personality(self.pers_idx)
            ag.example_generator.push_example(self.example_num)
            print('\n**Push personality {} back to stack. **\n'.format(
                self.pers_idx))
            print('\n**Push image {} back to stack. **\n'.format(
                self.example_num))
        if not self.second_resp:
            self.agents[0].personality_generator.save_idx_stack()
        self.agents[0].example_generator.save_idx_stack()
        data_path = self.opt['data_path']
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        if convo_finished:
            filename = os.path.join(
                data_path,
                '{}_{}_{}.pkl'.format(time.strftime('%Y%m%d-%H%M%S'),
                                      np.random.randint(0, 1000),
                                      self.task_type))
        else:
            filename = os.path.join(
                data_path,
                '{}_{}_{}_incomplete.pkl'.format(
                    time.strftime('%Y%m%d-%H%M%S'),
                    np.random.randint(0, 1000), self.task_type))
        # Flag workers who pasted the identical response every time.
        key = 'second_response' if self.second_resp else 'first_response'
        responses = [d[key] for d in self.data]
        if len(responses) >= 2:
            c = responses[0]
            if _exact_match(c, responses[1:]):
                self.exact_match = True
        # Keep only clean examples: the detector result is None when no
        # offensive language was found — TODO confirm detector contract.
        data_to_save = [
            d for d in self.data
            if d['contains_offensive_language'] is None
        ]
        pickle.dump(
            {
                'data': data_to_save,
                'worker': self.agents[0].worker_id,
                'hit_id': self.agents[0].hit_id,
                'assignment_id': self.agents[0].assignment_id,
                'exact_match': self.exact_match
            }, open(filename, 'wb'))
        print('{}: Data successfully saved at {}.'.format(
            self.world_tag, filename))

    def review_work(self):
        """Reject the HIT on offensive language or all-identical replies."""
        global review_agent

        def review_agent(ag):
            contains_offense = any(d['contains_offensive_language']
                                   for d in self.data)
            if contains_offense:
                ag.reject_work(reason='We have rejected this HIT because at '
                               'least one of your comments '
                               'contains offensive language')
                print(
                    'Rejected work for agent {} for offensive language'.format(
                        ag.worker_id))
            elif self.exact_match:
                ag.reject_work(reason='We have rejected this HIT because '
                               'all of your comments are the exact same')
                print('Rejected work for agent {} for same comments'.format(
                    ag.worker_id))

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(review_agent)(agent) for agent in self.agents)

    def shutdown(self):
        """Shutdown all mturk agents in parallel, otherwise if one mturk
        agent is disconnected then it could prevent other mturk agents
        from completing.
        """
        global shutdown_agent

        def shutdown_agent(agent):
            agent.shutdown()

        Parallel(n_jobs=len(self.agents), backend='threading')(
            delayed(shutdown_agent)(agent) for agent in self.agents)