Example No. 1
 def conv_process_fn(self):
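     """Build Recommender-turn training samples: truncate each utterance
     and the whole context, then record context and utterance lengths."""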
     dataset = []
     for conversation in tqdm(self.dataset):
         if conversation['role'] != 'Recommender':
             continue
         context_tokens = [
             truncate(utterance,
                      self.utterance_truncate,
                      truncate_tail=True)
             for utterance in conversation['context_tokens']
         ]
         context_tokens = truncate(context_tokens,
                                   self.conversation_truncate,
                                   truncate_tail=True)
         context_length = len(context_tokens)
         utterance_lengths = [
             len(utterance) for utterance in context_tokens
         ]
         request = context_tokens[-1]
         response = truncate(conversation['response'],
                             self.utterance_truncate,
                             truncate_tail=True)
         dataset.append({
             'context_tokens': context_tokens,
             'context_length': context_length,
             'utterance_lengths': utterance_lengths,
             'request': request,
             'response': response
         })
     return dataset
Example No. 2
    def conv_interact(self, data):
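        """Batchify a single dialogue at interaction (inference) time."""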
        context_tokens = [
            utter + [self.conv_bos_id] for utter in data['context_tokens']
        ]
        context_tokens[-1] = context_tokens[-1][:-1]
        context_tokens = [
            truncate(merge_utt(context_tokens),
                     max_length=self.context_truncate,
                     truncate_tail=False)
        ]
        context_tokens = padded_tensor(items=context_tokens,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.context_truncate,
                                       pad_tail=False)
        context_entities = [
            truncate(data['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False)
        ]
        context_words = [
            truncate(data['context_words'],
                     self.word_truncate,
                     truncate_tail=False)
        ]

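        # The single context tensor stands in for all four token inputs the
        # model expects (cf. the conv_batchify return in Example No. 11);
        # None replaces the gold response, which does not exist at inference.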
        return (context_tokens, context_tokens, context_tokens, context_tokens,
                padded_tensor(context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(context_words, self.pad_word_idx,
                              pad_tail=False), None)
Example No. 3
    def conv_batchify(self, batch):
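        """Collate training batches: left-truncate contexts, entities and
        words; wrap responses with start/end tokens (response_truncate - 2
        leaves room for them)."""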
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        for conv_dict in batch:
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']),
                         self.context_truncate,
                         truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'],
                                                 self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return (padded_tensor(batch_context_tokens,
                              self.pad_token_idx,
                              pad_tail=False),
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx))
Example No. 4
    def pretrain_batchify(self, batch):
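        """Collate pretraining batches: padded context words plus a one-hot
        encoding of the context entities."""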
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

        return (padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity))
Example No. 5
 def _process_rec_context(self, context_tokens):
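     """Merge context utterances, separated by sent_split_idx, into a single
     token sequence wrapped with start/end tokens."""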
     compact_context = []
     for i, utterance in enumerate(context_tokens):
         if i != 0:
             utterance.insert(0, self.sent_split_idx)
         compact_context.append(utterance)
     # context_truncate - 2 reserves room for the start/end tokens added below
     compact_context = truncate(merge_utt(compact_context),
                                self.context_truncate - 2,
                                truncate_tail=False)
     compact_context = add_start_end_token_idx(compact_context,
                                               self.start_token_idx,
                                               self.end_token_idx)
     return compact_context
Example No. 6
    def conv_batchify(self, batch):
        """Get a batch and the corresponding speaker roles
        (0 = Seeker, 1 = Recommender)."""
        batch_roles = []
        batch_context_tokens = []
        batch_response = []

        for conv_dict in batch:
            batch_roles.append(0 if conv_dict['role'] == 'Seeker' else 1)
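            # conv_bos_id separates turns: append it after every utterance,
            # then drop the trailing one from the last utterance.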
            context_tokens = [
                utter + [self.conv_bos_id]
                for utter in conv_dict['context_tokens']
            ]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens),
                         max_length=self.context_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(
                    conv_dict['response'],
                    max_length=self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response),
                                    dim=1)
        batch_roles = torch.tensor(batch_roles)

        return (batch_roles, batch_input_ids, batch_context_tokens,
                batch_response)
Example No. 7
    def rec_batchify(self, batch):
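        """Collate recommendation batches: padded entities and words, one-hot
        entity vectors, and the target item ids."""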
        batch_context_entities = []
        batch_context_words = []
        batch_item = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_item.append(conv_dict['item'])

        return (padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity),
                torch.tensor(batch_item, dtype=torch.long))
Example No. 8
    def _process_history(self, context_items, item_id=None):
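        """Left-truncate an interaction history, draw one negative sample per
        position, and (given item_id) build the shifted next-item targets."""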
        input_ids = truncate(context_items,
                             max_length=self.item_truncate,
                             truncate_tail=False)
        input_mask = [1] * len(input_ids)
        sample_negs = []
        seq_set = set(input_ids)
        for _ in input_ids:
            sample_negs.append(self._neg_sample(seq_set))

        if item_id is not None:
            target_pos = input_ids[1:] + [item_id]
            return input_ids, target_pos, input_mask, sample_negs
        else:
            return input_ids, input_mask, sample_negs
Example No. 9
    def conv_batchify(self, batch):
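        """Collate batches for slot-filling generation; return False when no
        response in the batch contains a replace token."""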
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        has_replace_slot = False
        batch_all_movies = []
        for conv_dict in batch:
            response = add_start_end_token_idx(
                truncate(conv_dict['response'], self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx)

            if response.count(self.replace_token_idx) != 0:
                has_replace_slot = True
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']),
                         self.context_truncate,
                         truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_response.append(response)

            # Keep one movie per replace slot in the response
            # (only movies, not all entities).
            batch_all_movies.append(
                truncate(conv_dict['items'],
                         response.count(self.replace_token_idx),
                         truncate_tail=False))
        if not has_replace_slot:  # no response in the batch has a slot to fill
            return False

        return (padded_tensor(batch_context_tokens,
                              self.pad_token_idx,
                              pad_tail=False),
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx),
                padded_tensor(batch_all_movies,
                              self.pad_entity_idx,
                              pad_tail=False))
Example No. 10
    def policy_batchify(self, batch):
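        """Collate policy (topic prediction) batches: token-encoded context,
        per-turn topic history, user profile, and the target topic."""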
        batch_context = []
        batch_context_policy = []
        batch_user_profile = []
        batch_target = []

        for conv_dict in batch:
            final_topic = conv_dict['final']
            final_topic = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.ind2topic[topic_id]
            ] for topic_id in final_topic[1]]
            final_topic = merge_utt(final_topic, self.word_split_idx, False,
                                    self.sep_id)

            context = conv_dict['context_tokens']
            context = merge_utt(context, self.sent_split_idx, False,
                                self.sep_id)
            context += final_topic
            context = add_start_end_token_idx(truncate(
                context,
                max_length=self.context_truncate - 1,
                truncate_tail=False),
                                              start_token_idx=self.cls_id)
            batch_context.append(context)

            # Flatten the per-turn policy history into [topic, topic, ...],
            # encoding each topic as a list of token ids.
            context_policy = []
            for policies_one_turn in conv_dict['context_policy']:
                for policy in policies_one_turn:
                    for topic_id in policy[1]:
                        if topic_id != self.pad_topic_idx:
                            topic_tokens = [
                                self.tok2ind.get(token, self.unk_token_idx)
                                for token in self.ind2topic[topic_id]
                            ]
                            context_policy.append(topic_tokens)
            context_policy = merge_utt(context_policy, self.word_split_idx,
                                       False)
            context_policy = add_start_end_token_idx(
                context_policy,
                start_token_idx=self.cls_id,
                end_token_idx=self.sep_id)
            context_policy += final_topic
            batch_context_policy.append(context_policy)

            batch_user_profile.extend(conv_dict['user_profile'])

            batch_target.append(conv_dict['target_topic'])

        batch_context = padded_tensor(batch_context,
                                      pad_idx=self.pad_token_idx,
                                      pad_tail=True,
                                      max_len=self.context_truncate)
        batch_context_mask = (batch_context != 0).long()
        batch_context_policy = padded_tensor(batch_context_policy,
                                             pad_idx=self.pad_token_idx,
                                             pad_tail=True)
        batch_context_policy_mask = (batch_context_policy != 0).long()
        batch_user_profile = padded_tensor(batch_user_profile,
                                           pad_idx=self.pad_token_idx,
                                           pad_tail=True)
        batch_user_profile_mask = (batch_user_profile != 0).long()
        batch_target = torch.tensor(batch_target, dtype=torch.long)

        return (batch_context, batch_context_mask, batch_context_policy,
                batch_context_policy_mask, batch_user_profile,
                batch_user_profile_mask, batch_target)
Example No. 11
    def conv_batchify(self, batch):
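        """Collate conversation batches plus an 'enhanced' context that
        prepends target topics or recommended items to the dialogue."""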
        batch_context_tokens = []
        batch_enhanced_context_tokens = []
        batch_response = []
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            context_tokens = [
                utter + [self.conv_bos_id]
                for utter in conv_dict['context_tokens']
            ]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens),
                         max_length=self.context_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(
                    conv_dict['response'],
                    max_length=self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

            enhanced_topic = []
            if 'target' in conv_dict:
                for target_policy in conv_dict['target']:
                    topic_variable = target_policy[1]
                    if isinstance(topic_variable, list):
                        for topic in topic_variable:
                            enhanced_topic.append(topic)
                enhanced_topic = [[
                    self.tok2ind.get(token, self.unk_token_idx)
                    for token in self.ind2topic[topic_id]
                ] for topic_id in enhanced_topic]
                enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx,
                                           False, self.sent_split_idx)

            enhanced_movie = []
            if 'items' in conv_dict:
                for movie_id in conv_dict['items']:
                    enhanced_movie.append(movie_id)
                # Character-level tokens of each movie name (text before '(').
                enhanced_movie = [[
                    self.tok2ind.get(token, self.unk_token_idx)
                    for token in self.id2entity[movie_id].split('(')[0]
                ] for movie_id in enhanced_movie]
                # As in the enhanced_topic call above, merge_utt's third
                # positional argument is a flag, so pass False before the
                # sentence-split index.
                enhanced_movie = truncate(merge_utt(enhanced_movie,
                                                    self.word_split_idx, False,
                                                    self.sent_split_idx),
                                          self.item_truncate,
                                          truncate_tail=False)

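            # Prepend movie tokens when available, else topic tokens, and
            # re-truncate the context so the combined length stays within
            # context_truncate.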
            if len(enhanced_movie) != 0:
                enhanced_context_tokens = enhanced_movie + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_movie),
                    truncate_tail=False)
            elif len(enhanced_topic) != 0:
                enhanced_context_tokens = enhanced_topic + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_topic),
                    truncate_tail=False)
            else:
                enhanced_context_tokens = batch_context_tokens[-1]
            batch_enhanced_context_tokens.append(enhanced_context_tokens)

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response),
                                    dim=1)
        batch_enhanced_context_tokens = padded_tensor(
            items=batch_enhanced_context_tokens,
            pad_idx=self.pad_token_idx,
            max_len=self.context_truncate,
            pad_tail=False)
        batch_enhanced_input_ids = torch.cat(
            (batch_enhanced_context_tokens, batch_response), dim=1)

        return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
                batch_input_ids, batch_context_tokens,
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False), batch_response)