示例#1
0
    def _get_persona_pool(self, opt, remove_duplicate=True):
        """Build the pool of personas and an index over it.

        Args:
            opt: options forwarded to ``make_path`` to locate ``train.txt``
                (only used when ``remove_duplicate`` is False).
            remove_duplicate: when True (default) the pool is built from
                ``train_persona_map.pkl`` by keeping, for every entry of the
                map, the longest persona. When False the pool is made of the
                persona part of the first turn of every episode in
                ``train.txt``.

        Returns:
            A tuple ``(all_personas, persona_to_idx)``: the list of unique
            persona strings in insertion order, and a dict mapping each
            persona string to its position in that list.
        """
        print("[loading persona pool from convai2 training data]")
        persona_set = OrderedSet()

        if remove_duplicate:
            # The pool comes entirely from the pickled persona map, so the
            # training file does not need to be parsed in this mode.
            train_persona_fname = os.path.join(__PATH__,
                                               'train_persona_map.pkl')
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; this is only safe while the map file is trusted.
            with open(train_persona_fname, 'rb') as fp:
                _train_personas = pickle.load(fp)
            for personas in _train_personas.values():
                # max() keeps the first of several equally long personas.
                longest_persona = max(personas, key=len)
                persona_set.add('\n'.join(
                    f"your persona: {x}." for x in longest_persona))
        else:
            # Only the first turn of each episode carries the persona, so
            # the file is streamed instead of materialising every episode.
            datapath = make_path(opt, 'train.txt')
            at_episode_start = True
            with open(datapath) as read:
                for line in read:
                    msg = str_to_msg(line.rstrip('\n'))
                    if not msg:
                        continue
                    if at_episode_start:
                        persona, _ = _split_persona_and_context(msg['text'])
                        persona_set.add(persona)
                    # The turn after an 'episode_done' one opens an episode.
                    at_episode_start = bool(msg.get('episode_done', False))

        all_personas = list(persona_set)
        persona_to_idx = {
            persona: i for i, persona in enumerate(all_personas)
        }

        print(f"Total {len(all_personas)} personas in dataset")

        return all_personas, persona_to_idx
示例#2
0
    def _setup_data(self, path):
        """Read a ParlAI-format text file into ``self.episodes``.

        Every line is parsed with ``str_to_msg``. Lines that parse to a
        falsy value are skipped. Parsed messages are validated, filtered
        according to ``self.opt['bad_speaker_to_eval']`` and
        ``self.opt['bad_safety_mix']``, optionally truncated to the last
        ``self.opt['bad_num_turns']`` lines of their text, and grouped into
        episodes on the ``episode_done`` flag. ``self.num_exs`` counts the
        messages that were kept.

        Args:
            path: location of the data file, opened through ``PathManager``.

        Raises:
            ValueError: if a message contains ``eval_labels`` or lacks a
                ``text`` or ``labels`` field.
        """
        logging.info(f"Loading ParlAI text data: {path}")

        self.episodes = []
        self.num_exs = 0
        eps = []
        with PathManager.open(path, newline='\n', encoding='utf-8') as read:
            for line_no, line in enumerate(read, 1):
                msg = str_to_msg(line.rstrip('\n'))
                if not msg:
                    # Nothing to validate or filter. Without this guard the
                    # filters below would index into a falsy message and
                    # crash as soon as one of them is enabled.
                    continue
                if 'eval_labels' in msg:
                    raise ValueError(
                        f"It looks like you've written eval_labels as a key in your "
                        f"data file. This is not appropriate; labels will be converted "
                        f"for you automatically. This is happening on Line {line_no} "
                        f"in {path}. The line is:\n\t{line}")
                if 'text' not in msg:
                    raise ValueError(
                        f'ParlaiDialogTeacher requires a "text" field in every '
                        f'entry, but one is missing in Line {line_no} in {path}. '
                        f'The line is:\n\t{line}')
                if 'labels' not in msg:
                    raise ValueError(
                        f'ParlaiDialogTeacher requires a "labels" field in every '
                        f'entry, but one is missing in Line {line_no} in {path}. '
                        f'The line is:\n\t{line}')

                # Keep only the requested speaker, unless 'all' is selected.
                if (self.opt['bad_speaker_to_eval'] != 'all'
                        and self.opt['bad_speaker_to_eval'] !=
                        msg['speaker_to_eval']):
                    continue
                # Keep only the requested safety class, unless 'all'.
                if (self.opt['bad_safety_mix'] != 'all'
                        and SAFETY_DICT[self.opt['bad_safety_mix']] !=
                        msg['labels'][0]):
                    continue
                # Truncate the text to its last `bad_num_turns` lines.
                if self.opt['bad_num_turns'] > 0:
                    dialog = msg['text'].split('\n')
                    msg.force_set(
                        'text', '\n'.join(dialog[-self.opt['bad_num_turns']:]))

                self.num_exs += 1
                eps.append(msg)
                if msg.get('episode_done', False):
                    self.episodes.append(eps)
                    eps = []
        if len(eps) > 0:
            # add last episode
            eps[-1].force_set('episode_done', True)
            self.episodes.append(eps)
        # `line_no` is only evaluated when an episode exists, i.e. when at
        # least one line has been read.
        if len(self.episodes) == 1 and line_no > 100:
            logging.error(
                f'The data in {path} looks like one very long episode. If this '
                f'is intentional, you may ignore this, but you MAY have a bug in '
                f'your data.')
示例#3
0
    def _get_context_pool(self, opt):
        """Group the training dialogues' label sequences by episode length.

        Reads ``train.txt`` (located through ``make_path(opt, ...)``), splits
        it into episodes on the ``episode_done`` flag and, for each episode,
        collects the first label of every turn.

        Returns:
            A dict mapping an episode length (number of turns) to the list of
            label sequences of all episodes having that length.
        """
        print("[loading history pool from convai2 training data]")
        train_file = make_path(opt, 'train.txt')

        dialogues = []
        current = []
        with open(train_file) as reader:
            for raw_line in reader:
                message = str_to_msg(raw_line.rstrip('\n'))
                if not message:
                    continue
                current.append(message)
                if message.get('episode_done', False):
                    dialogues.append(current)
                    current = []

        if current:
            # The file ended in the middle of an episode: close it.
            current[-1].force_set('episode_done', True)
            dialogues.append(current)

        pool = defaultdict(list)
        for dialogue in dialogues:
            first_labels = [turn['labels'][0] for turn in dialogue]
            pool[len(dialogue)].append(first_labels)

        return dict(pool)
示例#4
0
    def _setup_data(self, path, datatype):
        """Load evaluation episodes and attach distractor dialogue histories.

        Depending on ``self.opt['eval_type']`` the episodes are read from the
        ParlAI-format file at ``path`` ('convai2') or from the three DNLI
        ``.jsonl`` evaluation sets ('dnli'), and stored in ``self.episodes``
        with ``self.num_exs`` counting the examples. When
        ``self.opt['world_cardinality'] > 0`` each episode additionally gets
        a ``'distractor_text'`` field built from ``world_cardinality - 1``
        histories sampled from the pool returned by ``_get_context_pool``.

        Args:
            path: ParlAI-format data file; only read for 'convai2'.
            datatype: prefix used to build the auxiliary data file names.
        """
        # Fixed seed so that the same distractors are picked on every run.
        random.seed(46)
        # Data loading with script of ParlAIDialogTeacher
        print(f"[Loading ParlAI text data: {path}]")

        # Read data from ConvAI2
        convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt')
        convai2_episodes = self._load_convai2_data(convai2_datapath)

        if self.opt['eval_type'] == 'convai2':
            self.episodes = []
            self.num_exs = 0
            eps = []
            with open(path) as read:
                for line in read:
                    msg = str_to_msg(line.rstrip('\n'))
                    if msg:
                        self.num_exs += 1
                        eps.append(msg)
                        if msg.get('episode_done', False):
                            self.episodes.append(eps)
                            eps = []
            if len(eps) > 0:
                # add last episode
                eps[-1].force_set('episode_done', True)
                self.episodes.append(eps)
            # Add label candidates and both personas, taken from the turn at
            # the same (episode, turn) position in the ConvAI2 data.
            for episode_idx, episode in enumerate(self.episodes):
                for turn_idx, turn in enumerate(episode):
                    convai2_turn = convai2_episodes[episode_idx][turn_idx]
                    convai2_text = convai2_turn[0]
                    label_candidates = convai2_turn[3]

                    turn['label_candidates'] = label_candidates
                    if turn_idx == 0:
                        my_persona, partner_persona, _ = _split_personas_and_context(
                            convai2_text)
                        turn['partner_persona'] = partner_persona
                        turn['my_persona'] = my_persona
                    else:
                        # Later turns reuse the personas of the first turn.
                        turn['partner_persona'] = episode[0]['partner_persona']
                        turn['my_persona'] = episode[0]['my_persona']
        elif self.opt['eval_type'] == 'dnli':
            self.episodes = []
            self.num_exs = 0
            p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN
            for eval_set in ['attributes', 'havenot', 'likedislike']:
                datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl')
                with open(datapath, 'r') as fp:
                    for line in fp:
                        msg = json.loads(line)
                        msg['eval_set'] = eval_set
                        msg['episode_done'] = True

                        # Make 'text': persona lines followed by the dialogue.
                        # The last two characters of each persona sentence
                        # are dropped and replaced by a period.
                        lines = [
                            f'your persona: {x[:-2]}.' for x in msg['persona']
                        ]
                        # Identify the dialogue lines. It's assumed that p1 goes first.
                        for i, utt in enumerate(msg['prefix']):
                            speaker = p1_token if i % 2 == 0 else p2_token
                            lines.append(f'{speaker} {utt}')
                        msg['text'] = '\n'.join(lines)

                        # Make 'label_candidates'
                        cands = msg['candidates']
                        msg['label_candidates'] = cands['label'] + cands['neg'][:10] \
                            + cands['similar'][:10] + cands['rand'][:10]

                        # Remove unused attributes
                        for key in ('persona', 'prefix', 'triple',
                                    'relevant_persona_sentence', 'candidates'):
                            del msg[key]

                        self.episodes.append([msg])
                        self.num_exs += 1

        # Get dialogue history pool
        context_pool = self._get_context_pool(self.opt)

        # Add distractor history
        if self.opt['world_cardinality'] > 0:
            num_distractors = self.opt['world_cardinality'] - 1
            for episode in self.episodes:
                gt_persona, first_context = _split_persona_and_context(
                    episode[0]['text'], self.opt['eval_type'])

                # Select distractor history
                if self.opt['eval_type'] == 'convai2':
                    num_turn = len(episode)
                else:
                    dialogue = first_context.split('\n')
                    # Never go below the shortest history length available.
                    num_turn = max(math.ceil(len(dialogue) / 2),
                                   min(context_pool.keys()))

                candidates = context_pool[num_turn]
                context_indices = list(range(len(candidates)))
                distractor_c_indices = random.sample(context_indices,
                                                     num_distractors)
                # Index explicitly instead of itemgetter(*indices): with a
                # single index itemgetter returns the bare item (not a
                # tuple) and with no index it raises TypeError, which broke
                # world_cardinality values of 2 and 1.
                distractor_contexts = [
                    candidates[i] for i in distractor_c_indices
                ]

                # Make it to 'distractor_text'
                if self.opt['eval_type'] == 'convai2':
                    for turn_idx, turn in enumerate(episode):
                        turn['distractor_text'] = turn['labels'] + [
                            c[turn_idx] for c in distractor_contexts
                        ]
                        # 'my_context' accumulates the labels seen so far.
                        if turn_idx == 0:
                            turn['my_context'] = turn['labels']
                        else:
                            turn['my_context'] = episode[
                                turn_idx - 1]['my_context'] + turn['labels']
                else:
                    # DNLI: replace every odd-indexed utterance of the
                    # dialogue with the distractor's utterance.
                    # NOTE(review): the token is concatenated without the
                    # space used when the original text is built above -
                    # confirm this is intended.
                    p2_token = TorchAgent.P2_TOKEN
                    distractor_text = [episode[0]['text']]
                    for c in distractor_contexts:
                        swapped_dialogue = [
                            p2_token + c[turn_idx // 2]
                            if turn_idx % 2 == 1 else utterance
                            for turn_idx, utterance in enumerate(dialogue)
                        ]
                        distractor_text.append(
                            '\n'.join([gt_persona] + swapped_dialogue))
                    episode[0]['distractor_text'] = distractor_text
示例#5
0
    def _setup_data(self, path, datatype):
        """Load evaluation episodes and attach distractor personas.

        Depending on ``self.opt['eval_type']`` the episodes are read from the
        ParlAI-format file at ``path`` ('convai2') or from the three DNLI
        ``.jsonl`` evaluation sets ('dnli'), and stored in ``self.episodes``
        with ``self.num_exs`` counting the examples. When
        ``self.opt['world_cardinality'] > 0`` every turn additionally gets a
        ``'distractor_text'`` list of ``world_cardinality`` entries: for the
        first turn, the context prefixed with the ground-truth persona and
        with ``world_cardinality - 1`` randomly sampled other personas; for
        later turns, the turn's own text repeated.

        Args:
            path: ParlAI-format data file; only read for 'convai2'.
            datatype: prefix used to build the auxiliary data file names.
        """
        # Fixed seed so that the same distractors are picked on every run.
        random.seed(46)

        # Data loading with script of ParlAIDialogTeacher
        print(f"[Loading ParlAI text data: {path}]")

        # Read data from ConvAI2
        convai2_datapath = make_path(self.opt, f'{datatype}_both_original.txt')
        convai2_episodes = self._load_convai2_data(convai2_datapath)

        # Get persona pool
        all_personas, persona_to_idx = self._get_persona_pool(self.opt)
        # NOTE(review): the result is not used below; the call is kept in
        # case it has side effects - confirm and remove if it has none.
        sorted_personas = self._get_sorted_persona_pool(datatype)

        if self.opt['eval_type'] == 'convai2':
            self.episodes = []
            self.num_exs = 0
            eps = []
            with open(path) as read:
                for line in read:
                    msg = str_to_msg(line.rstrip('\n'))
                    if msg:
                        self.num_exs += 1
                        eps.append(msg)
                        if msg.get('episode_done', False):
                            self.episodes.append(eps)
                            eps = []
            if len(eps) > 0:
                # add last episode
                eps[-1].force_set('episode_done', True)
                self.episodes.append(eps)
            # Add label candidates and both personas, taken from the turn at
            # the same (episode, turn) position in the ConvAI2 data.
            for episode_idx, episode in enumerate(self.episodes):
                for turn_idx, turn in enumerate(episode):
                    convai2_turn = convai2_episodes[episode_idx][turn_idx]
                    convai2_text = convai2_turn[0]
                    label_candidates = convai2_turn[3]

                    turn['label_candidates'] = label_candidates
                    if turn_idx == 0:
                        my_persona, partner_persona, _ = _split_personas_and_context(
                            convai2_text)
                        turn['partner_persona'] = partner_persona
                        turn['my_persona'] = my_persona
                    else:
                        # Later turns reuse the personas of the first turn.
                        turn['partner_persona'] = episode[0]['partner_persona']
                        turn['my_persona'] = episode[0]['my_persona']
        elif self.opt['eval_type'] == 'dnli':
            self.episodes = []
            self.num_exs = 0
            p1_token, p2_token = TorchAgent.P1_TOKEN, TorchAgent.P2_TOKEN
            for eval_set in ['attributes', 'havenot', 'likedislike']:
                datapath = make_path(self.opt, f'{datatype}_{eval_set}.jsonl')
                with open(datapath, 'r') as fp:
                    for line in fp:
                        msg = json.loads(line)
                        msg['eval_set'] = eval_set
                        msg['episode_done'] = True

                        # Make 'text': persona lines followed by the dialogue.
                        # The last two characters of each persona sentence
                        # are dropped and replaced by a period.
                        lines = [
                            f'your persona: {x[:-2]}.' for x in msg['persona']
                        ]
                        # Identify the dialogue lines. It's assumed that p1 goes first.
                        for i, utt in enumerate(msg['prefix']):
                            speaker = p1_token if i % 2 == 0 else p2_token
                            lines.append(f'{speaker} {utt}')
                        msg['text'] = '\n'.join(lines)

                        # Make 'label_candidates'
                        cands = msg['candidates']
                        msg['label_candidates'] = cands['label'] + cands['neg'][:10] \
                            + cands['similar'][:10] + cands['rand'][:10]

                        # Remove unused attributes
                        for key in ('persona', 'prefix', 'triple',
                                    'relevant_persona_sentence', 'candidates'):
                            del msg[key]

                        self.episodes.append([msg])
                        self.num_exs += 1

        # Add distractor personas
        if self.opt['world_cardinality'] > 0:
            num_all_personas = len(all_personas)
            persona_indices = list(range(num_all_personas))
            world_cardinality = self.opt['world_cardinality']
            for episode in self.episodes:
                gt_persona, first_context = _split_persona_and_context(
                    episode[0]['text'], self.opt['eval_type'])
                # -1 (never a sampled index) when the persona is not pooled.
                gt_persona_idx = persona_to_idx.get(gt_persona, -1)

                # Choose random distractor personas
                distractor_indices = random.sample(persona_indices,
                                                   world_cardinality - 1)
                while gt_persona_idx in distractor_indices:
                    # Resample if gt_persona is sampled
                    distractor_indices = random.sample(persona_indices,
                                                       world_cardinality - 1)
                # Index explicitly instead of itemgetter(*indices): with a
                # single index itemgetter returns the bare persona string
                # (which list() would split into characters) and with no
                # index it raises TypeError, which broke world_cardinality
                # values of 2 and 1.
                distractor_personas = [
                    all_personas[i] for i in distractor_indices
                ]

                # Make it to 'distractor_text'
                for turn_idx, turn in enumerate(episode):
                    if turn_idx == 0:
                        # Ground-truth persona first, then the distractors.
                        turn['distractor_text'] = [
                            '\n'.join([persona, first_context])
                            for persona in [gt_persona] + distractor_personas
                        ]
                    else:
                        turn['distractor_text'] = [turn['text']
                                                   ] * world_cardinality