def __init__(self, opt, datatype: str = 'train', seed: Optional[int] = None):
    """
    Initialize the context generator.

    Builds one ConvAI2 (BothTeacher), one EmpatheticDialogues and one Wizard of
    Wikipedia teacher for the requested split, then the lookup tables used to
    pair personas with WoW topics.

    opt: only a 'datapath' key is required, to specify the ParlAI data folder
    datatype: split name placed in every teacher's opt
    seed: when given, seeds the private random.Random used for all sampling
    """
    self.rng = random.Random() if seed is None else random.Random(seed)

    data_folder = opt['datapath']
    split_opt = {'datapath': data_folder, 'datatype': datatype}

    self.convai2_teacher = BothTeacher(Opt(dict(split_opt)))

    # train_experiencer_only=True makes each example's text a Speaker utterance
    # and its label the Listener's response
    self.ed_teacher = EmpatheticDialoguesTeacher(
        Opt({**split_opt, 'train_experiencer_only': True})
    )

    self.wow_teacher = WizardDialogKnowledgeTeacher(Opt(dict(split_opt)))

    self.topic_to_persona_path = _topic_to_persona_path(opt)
    # Build the topic -> episode table before the persona -> topic table:
    # ContextGenerator._setup_personas_to_topics reads
    # self.wow_topics_to_episode_idxes while building its map.
    self.wow_topics_to_episode_idxes = self._setup_topics_to_episodes()
    self.persona_strings_to_wow_topics = self._setup_personas_to_topics()
class ContextGenerator:
    """
    Generates contexts shown to crowdsourced workers when collecting BST conversations.

    This generator was used to generate the context information shown to workers at the
    beginning of a conversation, when crowdsourcing the conversations that make up the
    BST dataset.
    """

    def __init__(self, opt, datatype: str = 'train', seed: Optional[int] = None):
        """
        Initialize the context generator.

        opt: only a 'datapath' key is required, to specify the ParlAI data folder
        datatype: the data split passed to the ConvAI2, ED, and WoW teachers
        seed: if given, seeds the private random.Random instance that drives all
            sampling in this generator
        """
        if seed is not None:
            self.rng = random.Random(seed)
        else:
            self.rng = random.Random()

        convai2_opt = Opt({'datapath': opt['datapath'], 'datatype': datatype})
        self.convai2_teacher = BothTeacher(convai2_opt)

        ed_opt = Opt(
            {
                'datapath': opt['datapath'],
                'datatype': datatype,
                'train_experiencer_only': True,
            }
        )
        # Specify train_experiencer_only = True because we want to ensure that the text
        # will correspond to a Speaker utterance and the label to a Listener response
        self.ed_teacher = EmpatheticDialoguesTeacher(ed_opt)

        wow_opt = Opt({'datapath': opt['datapath'], 'datatype': datatype})
        self.wow_teacher = WizardDialogKnowledgeTeacher(wow_opt)

        self.topic_to_persona_path = _topic_to_persona_path(opt)
        # The topic -> episode map must be built first: the persona -> topic map drops
        # any topic that is missing from it
        self.wow_topics_to_episode_idxes = self._setup_topics_to_episodes()
        self.persona_strings_to_wow_topics = self._setup_personas_to_topics()

    def get_context(self) -> dict:
        """
        Get context information to be shown at the beginning of one conversation.

        Values in return dict:
        - context_dataset: the dataset (ConvAI2, EmpatheticDialogues, or Wizard of
            Wikipedia) used to generate the context information.
        - persona_1_strings, persona_2_strings: 2 persona strings each for the two
            speakers, chosen randomly from the ConvAI2 dataset. If context_dataset ==
            "wizard_of_wikipedia", these persona strings will be matched to the WoW
            topic returned in the "additional_context" field.
        - additional_context: provides additional bits of information to give context
            for the speakers. If context_dataset == "empathetic_dialogues", this is a
            situation from the start of an ED conversation. If context_dataset ==
            "wizard_of_wikipedia", this is a topic from the WoW dataset that matches
            the persona strings. If context_dataset == "convai2", this is None.
        - person1_seed_utterance, person2_seed_utterance: two lines of a conversation
            from the dataset specified by "context_dataset". They will be shown to the
            speakers to "seed" the conversation, and the speakers continue from where
            the lines left off.
        """
        # Determine which dataset we will show context for (each with probability 1/3)
        rand_value = self.rng.random()
        if rand_value < 1 / 3:
            return self._get_convai2_context()
        elif rand_value < 2 / 3:
            return self._get_empathetic_dialogues_context()
        else:
            return self._get_wizard_of_wikipedia_context()

    def _sample_persona_pair(self, episode_idx: int) -> Tuple[List[str], List[str]]:
        """
        Sample 2 persona strings for each speaker of the given ConvAI2 episode.

        Returns (persona_1_strings, persona_2_strings); persona 1 is sampled first.
        """
        persona_1_strings, persona_2_strings = self._extract_personas(episode_idx)
        selected_persona_1_strings = self.rng.sample(persona_1_strings, 2)
        selected_persona_2_strings = self.rng.sample(persona_2_strings, 2)
        return selected_persona_1_strings, selected_persona_2_strings

    def _sample_including(
        self, persona_strings: List[str], required_string: str
    ) -> List[str]:
        """
        Return 2 persona strings in random order, one of which is required_string.

        The second string is drawn from the strings that differ from required_string.
        """
        remaining_strings = [
            str_ for str_ in persona_strings if str_ != required_string
        ]
        selected_strings = [
            required_string,
            self.rng.sample(remaining_strings, k=1)[0],
        ]
        self.rng.shuffle(selected_strings)
        return selected_strings

    def _get_convai2_context(self) -> dict:
        """
        Build a context whose personas and seed utterances come from one ConvAI2 episode.
        """
        # Select episode
        episode_idx = self.rng.randrange(self.convai2_teacher.num_episodes())

        # Extract and sample persona strings
        (
            selected_persona_1_strings,
            selected_persona_2_strings,
        ) = self._sample_persona_pair(episode_idx)

        # Select previous utterances
        num_entries = len(self.convai2_teacher.data.data[episode_idx])
        entry_idx = self.rng.randrange(1, num_entries)
        # Don't select the first entry, which often doesn't include an apprentice
        # utterance
        chosen_entry = self.convai2_teacher.get(episode_idx, entry_idx=entry_idx)
        person1_seed_utterance = chosen_entry['text']
        assert len(chosen_entry['labels']) == 1
        person2_seed_utterance = chosen_entry['labels'][0]
        return {
            'context_dataset': 'convai2',
            'persona_1_strings': selected_persona_1_strings,
            'persona_2_strings': selected_persona_2_strings,
            'additional_context': None,
            'person1_seed_utterance': person1_seed_utterance,
            'person2_seed_utterance': person2_seed_utterance,
        }

    def _get_empathetic_dialogues_context(self) -> dict:
        """
        Build a context with random ConvAI2 personas and the opening of an ED episode.
        """
        # Select a ConvAI2 episode for the personas
        persona_episode_idx = self.rng.randrange(self.convai2_teacher.num_episodes())

        # Extract and sample persona strings
        (
            selected_persona_1_strings,
            selected_persona_2_strings,
        ) = self._sample_persona_pair(persona_episode_idx)

        # Select previous utterances
        episode_idx = self.rng.randrange(self.ed_teacher.num_episodes())
        entry_idx = 0  # We'll only use the first pair of utterances
        entry = self.ed_teacher.get(episode_idx, entry_idx=entry_idx)
        situation = entry['situation']
        speaker_utterance = entry['text']
        assert len(entry['labels']) == 1
        listener_response = entry['labels'][0]
        return {
            'context_dataset': 'empathetic_dialogues',
            'persona_1_strings': selected_persona_1_strings,
            'persona_2_strings': selected_persona_2_strings,
            'additional_context': situation,
            'person1_seed_utterance': speaker_utterance,
            'person2_seed_utterance': listener_response,
        }

    def _get_wizard_of_wikipedia_context(self) -> dict:
        """
        Build a context from ConvAI2 personas, a matching WoW topic, and a WoW exchange.

        One of the selected persona strings is always a string mapped to the returned
        topic.
        """
        # Pull different personas until you get a pair for which at least one
        # sentence has a WoW topic bound to it.
        # NOTE: there is no retry limit; this never terminates if no ConvAI2 persona
        # string maps to a WoW topic.
        num_tries = 0
        while True:
            num_tries += 1
            # Extract a random (matched) pair of personas
            persona_episode_idx = self.rng.randrange(
                self.convai2_teacher.num_episodes()
            )
            persona_1_strings, persona_2_strings = self._extract_personas(
                persona_episode_idx
            )
            all_persona_strings = {1: persona_1_strings, 2: persona_2_strings}

            # See if any of the persona strings have a matching WoW topic. Use .get()
            # so that lookups don't insert empty entries into the defaultdict.
            matching_persona_string_idxes = []
            for persona_idx, persona_strings in all_persona_strings.items():
                for str_idx, str_ in enumerate(persona_strings):
                    if self.persona_strings_to_wow_topics.get(str_):
                        matching_persona_string_idxes.append((persona_idx, str_idx))
            if matching_persona_string_idxes:
                break
        print(
            f'{num_tries:d} try/tries needed to find a pair of personas with an '
            f'associated WoW topic.'
        )

        # Pick out the WoW topic and matching persona string
        matching_persona_idx, matching_persona_string_idx = self.rng.sample(
            matching_persona_string_idxes, k=1
        )[0]
        matching_persona_string = all_persona_strings[matching_persona_idx][
            matching_persona_string_idx
        ]
        wow_topic = self.rng.sample(
            self.persona_strings_to_wow_topics[matching_persona_string], k=1
        )[0]

        # Sample persona strings, making sure that we keep the one connected to the
        # WoW topic (persona 1 is always sampled before persona 2)
        if matching_persona_idx == 1:
            selected_persona_1_strings = self._sample_including(
                all_persona_strings[1], matching_persona_string
            )
            selected_persona_2_strings = self.rng.sample(all_persona_strings[2], 2)
        else:
            selected_persona_1_strings = self.rng.sample(all_persona_strings[1], 2)
            selected_persona_2_strings = self._sample_including(
                all_persona_strings[2], matching_persona_string
            )

        # Sample WoW previous utterances, given the topic
        episode_idx = self.rng.sample(
            self.wow_topics_to_episode_idxes[wow_topic], k=1
        )[0]
        entry_idx = 1
        # Select the second entry, which (unlike the first entry) will always have
        # two valid utterances and which will not usually be so far along in the
        # conversation that the new Turkers will be confused
        entry = self.wow_teacher.get(episode_idx, entry_idx=entry_idx)
        apprentice_utterance = entry['text']
        assert len(entry['labels']) == 1
        wizard_utterance = entry['labels'][0]
        return {
            'context_dataset': 'wizard_of_wikipedia',
            'persona_1_strings': selected_persona_1_strings,
            'persona_2_strings': selected_persona_2_strings,
            'additional_context': wow_topic,
            'person1_seed_utterance': apprentice_utterance,
            'person2_seed_utterance': wizard_utterance,
        }

    def _setup_personas_to_topics(self) -> Dict[str, List[str]]:
        """
        Create a map from ConvAI2 personas to WoW topics that they correspond to.

        Each non-blank line of the file at self.topic_to_persona_path must have the
        form ``<topic>: [<persona string>, ...]``. Topics that do not appear in
        self.wow_topics_to_episode_idxes are skipped.

        Raises ValueError for a line that does not have that form.
        """
        # Local import so this block does not depend on the file's import list
        import ast

        line_pattern = re.compile(r'([^[]+): (\[.+\])')
        print('Starting to map personas to topics.')
        persona_strings_to_topics = defaultdict(list)
        with PathManager.open(self.topic_to_persona_path, 'r') as f:
            for line in f:
                # Strip the line ending ourselves so the last line parses even when
                # the file has no trailing newline
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                match = line_pattern.fullmatch(line)
                if match is None:
                    raise ValueError(
                        f'Cannot parse topic-to-persona line: {line!r}'
                    )
                topic = match.group(1)
                if topic not in self.wow_topics_to_episode_idxes:
                    continue
                # literal_eval parses the list literal without executing arbitrary
                # code (the original used eval)
                persona_strings = ast.literal_eval(match.group(2))
                assert isinstance(persona_strings, list)
                for str_ in persona_strings:
                    persona_strings_to_topics[str_].append(topic)
        print('Finished mapping personas to topics.')
        return persona_strings_to_topics

    def _setup_topics_to_episodes(self) -> Dict[str, List[int]]:
        """
        Create a map from WoW topics to the indices of the WoW episodes that use them.

        The topic of an episode is the 'chosen_topic' field of its first entry.
        """
        print('Starting to map topics to episodes.')
        topics_to_episodes = defaultdict(list)
        for episode_idx in range(self.wow_teacher.num_episodes()):
            topic = self.wow_teacher.get(episode_idx, entry_idx=0)['chosen_topic']
            topics_to_episodes[topic].append(episode_idx)
        print('Finished mapping topics to episodes.')
        return topics_to_episodes

    def _extract_personas(self, episode_idx: int) -> Tuple[List[str], List[str]]:
        """
        For the given ConvAI2 conversation, return strings of both speakers' personas.

        Persona lines are read from the text of the episode's first entry: lines
        prefixed with "your persona: " belong to Person 2, and lines prefixed with
        "partner's persona: " belong to Person 1. The last line of that text is the
        first utterance and is skipped.

        Raises ValueError if any other line precedes the first utterance.
        """
        your_prefix = 'your persona: '
        partner_prefix = "partner's persona: "
        first_entry = self.convai2_teacher.get(episode_idx, entry_idx=0)
        first_text_strings = first_entry['text'].split('\n')
        persona_1_strings = []
        persona_2_strings = []
        for str_ in first_text_strings[:-1]:
            # The last string is the first utterance
            if str_.startswith(your_prefix):
                # Here, "you" are Person 2
                persona_2_strings.append(str_[len(your_prefix):])
            elif str_.startswith(partner_prefix):
                persona_1_strings.append(str_[len(partner_prefix):])
            else:
                raise ValueError('Persona string cannot be parsed!')
        return persona_1_strings, persona_2_strings