Example #1
    def conv_interact(self, data):
        context_tokens = [
            utter + [self.conv_bos_id] for utter in data['context_tokens']
        ]
        context_tokens[-1] = context_tokens[-1][:-1]  # drop the bos appended after the last utterance
        context_tokens = [
            truncate(merge_utt(context_tokens),
                     max_length=self.context_truncate,
                     truncate_tail=False)
        ]
        context_tokens = padded_tensor(items=context_tokens,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.context_truncate,
                                       pad_tail=False)
        context_entities = [
            truncate(data['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False)
        ]
        context_words = [
            truncate(data['context_words'],
                     self.word_truncate,
                     truncate_tail=False)
        ]

        # at interaction time there is no response or enhanced context yet,
        # so the same context tensor fills all four token fields
        return (context_tokens, context_tokens, context_tokens, context_tokens,
                padded_tensor(context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(context_words, self.pad_word_idx,
                              pad_tail=False), None)
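truncate, merge_utt and padded_tensor are toolkit utilities defined elsewhere. Below is a minimal sketch of the first two, consistent with how this method calls them; it is an illustration under those assumptions, not the toolkit's code.

    # Minimal sketches, assuming the call signatures used above.
    def truncate(vec, max_length, truncate_tail=True):
        """Keep at most max_length items, cutting the tail or the head."""
        if max_length is None or len(vec) <= max_length:
            return vec
        return vec[:max_length] if truncate_tail else vec[-max_length:]

    def merge_utt(conversation):
        """Flatten a list of utterances (lists of token ids) into one list."""
        return [token for utter in conversation for token in utter]

    # Note the one-element lists throughout conv_interact: interaction
    # processes a single dialogue, so every field becomes a batch of size 1.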
Example #2
    def rec_interact(self, data):
        context = [self._process_rec_context(data['context_tokens'])]
        if 'interaction_history' in data:
            context_items = data['interaction_history'] + data['context_items']
        else:
            context_items = data['context_items']
        input_ids, input_mask, sample_negs = self._process_history(
            context_items)
        input_ids, input_mask, sample_negs = [input_ids], [input_mask], [sample_negs]

        context = padded_tensor(context,
                                self.pad_token_idx,
                                max_len=self.context_truncate)
        mask = (context != self.pad_token_idx).long()

        return (context, mask,
                padded_tensor(input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate), None,
                padded_tensor(input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate), None)
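The mask line marks real tokens with 1 and padding with 0. A toy check of that pattern, assuming a BERT-style vocabulary where the pad id is 0 (the ids below are made up):

    import torch

    pad_token_idx = 0
    context = torch.tensor([[0, 0, 101, 2054, 102],
                            [0, 101, 2129, 2024, 102]])
    mask = (context != pad_token_idx).long()
    # tensor([[0, 0, 1, 1, 1],
    #         [0, 1, 1, 1, 1]])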
Example #3
    def conv_batchify(self, batch):
        max_utterance_length = max(
            [max(conversation['utterance_lengths']) for conversation in batch])
        max_response_length = max(
            [len(conversation['response']) for conversation in batch])
        max_utterance_length = max(max_utterance_length, max_response_length)
        max_context_length = max(
            [conversation['context_length'] for conversation in batch])
        batch_context = []
        batch_context_length = []
        batch_utterance_lengths = []
        batch_request = []  # padded to a tensor below
        batch_request_length = []
        batch_response = []

        for conversation in batch:
            padded_context = padded_tensor(conversation['context_tokens'],
                                           pad_idx=self.pad_token_idx,
                                           pad_tail=True,
                                           max_len=max_utterance_length)
            if len(conversation['context_tokens']) < max_context_length:
                # pad the conversation with empty utterances up to max_context_length
                pad_tensor = padded_context.new_full(
                    (max_context_length - len(conversation['context_tokens']),
                     max_utterance_length), self.pad_token_idx)
                padded_context = torch.cat((padded_context, pad_tensor), 0)
            batch_context.append(padded_context)
            batch_context_length.append(conversation['context_length'])
            batch_utterance_lengths.append(
                conversation['utterance_lengths'] + [0] *
                (max_context_length - len(conversation['context_tokens'])))

            request = conversation['request']
            batch_request_length.append(len(request))
            batch_request.append(request)

            response = copy(conversation['response'])
            # mask raw movie-id tokens (those matching movie_pattern,
            # e.g. '^\d{5,6}$') with the generic __item__ token
            for i in range(len(response)):
                if movie_pattern.match(self.ind2tok[response[i]]):
                    response[i] = self.item_token_idx
            batch_response.append(response)

        context = torch.stack(batch_context, dim=0)
        request = padded_tensor(batch_request,
                                self.pad_token_idx,
                                pad_tail=True,
                                max_len=max_utterance_length)
        response = padded_tensor(batch_response,
                                 self.pad_token_idx,
                                 pad_tail=True,
                                 max_len=max_utterance_length)  # (bs, utt_len)

        return {
            'context': context,
            'context_lengths': torch.tensor(batch_context_length),
            'utterance_lengths': torch.tensor(batch_utterance_lengths),
            'request': request,
            'request_lengths': torch.tensor(batch_request_length),
            'response': response
        }
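Unlike the flat-context batchifiers elsewhere on this page, this method keeps utterances separate: each conversation becomes a (max_context_length, max_utterance_length) matrix, padded with empty utterances, and the matrices are stacked into a 3-D batch. A toy sketch of that shape logic, with made-up values:

    import torch

    pad = 0
    max_utt_len, max_ctx_len = 4, 3
    conv = [[5, 6], [7, 8, 9]]                       # two utterances
    rows = [u + [pad] * (max_utt_len - len(u)) for u in conv]
    rows += [[pad] * max_utt_len] * (max_ctx_len - len(conv))
    context = torch.tensor(rows)                     # shape (3, 4)
    batch = torch.stack([context], dim=0)            # shape (1, 3, 4)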
Example #4
    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        for conv_dict in batch:
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']),
                         self.context_truncate,
                         truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'],
                                                 self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return (padded_tensor(batch_context_tokens,
                              self.pad_token_idx,
                              pad_tail=False),
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx))
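Two conventions above are worth spelling out: truncate_tail=False keeps the most recent tokens, and pad_tail=False left-pads so each sequence ends flush with its last real token. A toy illustration with hypothetical ids and pad id 0:

    seq = [3, 4, 5, 6, 7]
    truncated = seq[-4:]                   # truncate_tail=False -> [4, 5, 6, 7]
    padded = [0] * (6 - len(truncated)) + truncated
    # pad_tail=False -> [0, 0, 4, 5, 6, 7]: the context ends at the real tokens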
Example #5
    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        flag = False
        batch_all_movies = []
        for conv_dict in batch:
            # build the response once and reuse it below
            response = add_start_end_token_idx(
                truncate(conv_dict['response'], self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx)

            if self.replace_token_idx in response:
                flag = True
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']),
                         self.context_truncate,
                         truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_response.append(response)

            batch_all_movies.append(
                truncate(conv_dict['items'],
                         response.count(self.replace_token_idx),
                         truncate_tail=False))  # only movies (items), not all entities
        if not flag:  # no replace-token slot anywhere in this batch: skip it
            return False

        return (padded_tensor(batch_context_tokens,
                              self.pad_token_idx,
                              pad_tail=False),
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx),
                padded_tensor(batch_all_movies,
                              self.pad_entity_idx,
                              pad_tail=False))
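Because this batchify can return False, whatever consumes it has to filter. A hypothetical wrapper illustrating that contract (the names here are placeholders, not toolkit code):

    def training_batches(dataloader, batchify):
        """Hypothetical: yield only batches that contain a movie slot."""
        for raw_batch in dataloader:          # yields lists of conv_dicts
            batch = batchify(raw_batch)
            if batch is False:                # zero replace-token slots
                continue
            yield batch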
Example #6
    def rec_batchify(self, batch):
        batch_context = []
        batch_movie_id = []
        batch_input_ids = []
        batch_target_pos = []
        batch_input_mask = []
        batch_sample_negs = []

        for conv_dict in batch:
            context = self._process_rec_context(conv_dict['context_tokens'])
            batch_context.append(context)

            item_id = conv_dict['item']
            batch_movie_id.append(item_id)

            if 'interaction_history' in conv_dict:
                context_items = conv_dict['interaction_history'] + conv_dict[
                    'context_items']
            else:
                context_items = conv_dict['context_items']

            input_ids, target_pos, input_mask, sample_negs = self._process_history(
                context_items, item_id)
            batch_input_ids.append(input_ids)
            batch_target_pos.append(target_pos)
            batch_input_mask.append(input_mask)
            batch_sample_negs.append(sample_negs)

        batch_context = padded_tensor(batch_context,
                                      self.pad_token_idx,
                                      max_len=self.context_truncate)
        batch_mask = (batch_context != self.pad_token_idx).long()

        return (batch_context, batch_mask,
                padded_tensor(batch_input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_target_pos,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                torch.tensor(batch_movie_id))
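_process_history is defined elsewhere in the class. The sketch below is one plausible reading of what it returns for SASRec-style next-item training, with one sampled negative per target position; the helper name, signature, and logic are assumptions, not the toolkit's implementation.

    import random

    def process_history(context_items, item_id, max_len, n_items):
        """Hypothetical: inputs, shifted targets, mask, and random negatives."""
        seq = context_items[-max_len:]        # assumes a non-empty history
        input_ids = seq
        target_pos = seq[1:] + [item_id]      # targets shifted by one position
        input_mask = [1] * len(seq)
        sample_negs = []
        for pos_item in target_pos:           # one random negative per target
            neg = random.randint(1, n_items - 1)
            while neg == pos_item:
                neg = random.randint(1, n_items - 1)
            sample_negs.append(neg)
        return input_ids, target_pos, input_mask, sample_negs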
Example #7
    def conv_batchify(self, batch):
        """get batch and corresponding roles
        """
        batch_roles = []
        batch_context_tokens = []
        batch_response = []

        for conv_dict in batch:
            batch_roles.append(0 if conv_dict['role'] == 'Seeker' else 1)
            context_tokens = [
                utter + [self.conv_bos_id]
                for utter in conv_dict['context_tokens']
            ]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens),
                         max_length=self.context_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(
                    conv_dict['response'],
                    max_length=self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response),
                                    dim=1)
        batch_roles = torch.tensor(batch_roles)

        return (batch_roles, batch_input_ids, batch_context_tokens,
                batch_response)
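The padding directions are deliberate: the context is left-padded and the response right-padded, so after concatenation the real tokens form one contiguous run, which is what a left-to-right language model expects. A toy check with made-up ids:

    import torch

    PAD = 0
    context = torch.tensor([[PAD, PAD, 11, 12]])    # pad_tail=False
    response = torch.tensor([[13, 14, PAD, PAD]])   # pad_tail=True
    input_ids = torch.cat((context, response), dim=1)
    # tensor([[0, 0, 11, 12, 13, 14, 0, 0]]) -- real tokens stay contiguous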
Example #8
    def rec_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        batch_item = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))
            batch_item.append(conv_dict['item'])

        return (padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity),
                torch.tensor(batch_item, dtype=torch.long))
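get_onehot turns each list of entity ids into a multi-hot row. A minimal sketch consistent with this call site (the toolkit's implementation may differ):

    import torch

    def get_onehot(id_lists, n_entity):
        """Multi-hot encode entity-id lists into a (bs, n_entity) tensor."""
        onehot = torch.zeros(len(id_lists), n_entity)
        for row, ids in enumerate(id_lists):
            for idx in ids:
                onehot[row, idx] = 1
        return onehot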
Example #9
    def pretrain_batchify(self, batch):
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

        return (padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity))
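The returned pair suggests a multi-label pretraining objective: predict the multi-hot entity bag from the word context. A toy sketch of such a loss; the sizes and the objective itself are assumptions:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 6)                 # hypothetical (bs, n_entity) scores
    entity_onehot = torch.tensor([[0., 1., 0., 0., 1., 0.],
                                  [1., 0., 0., 0., 0., 0.]])
    loss = F.binary_cross_entropy_with_logits(logits, entity_onehot)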
Example #10
    def rec_batchify(self, batch):
        batch_context = []
        batch_movie_id = []

        for conv_dict in batch:
            context = self._process_rec_context(conv_dict['context_tokens'])
            batch_context.append(context)

            item_id = conv_dict['item']
            batch_movie_id.append(item_id)

        batch_context = padded_tensor(batch_context,
                                      self.pad_token_idx,
                                      max_len=self.context_truncate)
        batch_mask = (batch_context != self.pad_token_idx).long()

        return (batch_context, batch_mask, torch.tensor(batch_movie_id))
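batch_movie_id becomes a plain 1-D LongTensor because it feeds a standard classification loss over items. A toy sketch with made-up scores and sizes:

    import torch
    import torch.nn.functional as F

    scores = torch.randn(2, 5)                 # hypothetical (bs, n_items) output
    movie_ids = torch.tensor([3, 1])           # targets, as built above
    loss = F.cross_entropy(scores, movie_ids)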
Example #11
    def policy_batchify(self, batch):
        batch_context = []
        batch_context_policy = []
        batch_user_profile = []
        batch_target = []

        for conv_dict in batch:
            final_topic = conv_dict['final']
            final_topic = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.ind2topic[topic_id]
            ] for topic_id in final_topic[1]]
            final_topic = merge_utt(final_topic, self.word_split_idx, False,
                                    self.sep_id)

            context = conv_dict['context_tokens']
            context = merge_utt(context, self.sent_split_idx, False,
                                self.sep_id)
            context += final_topic
            context = add_start_end_token_idx(truncate(
                context,
                max_length=self.context_truncate - 1,
                truncate_tail=False),
                                              start_token_idx=self.cls_id)
            batch_context.append(context)

            # flatten context_policy into [topic_tokens, topic_tokens, ...]
            context_policy = []
            for policies_one_turn in conv_dict['context_policy']:
                for policy in policies_one_turn:
                    for topic_id in policy[1]:
                        if topic_id != self.pad_topic_idx:
                            topic_tokens = [
                                self.tok2ind.get(token, self.unk_token_idx)
                                for token in self.ind2topic[topic_id]
                            ]
                            context_policy.append(topic_tokens)
            context_policy = merge_utt(context_policy, self.word_split_idx,
                                       False)
            context_policy = add_start_end_token_idx(
                context_policy,
                start_token_idx=self.cls_id,
                end_token_idx=self.sep_id)
            context_policy += final_topic
            batch_context_policy.append(context_policy)

            batch_user_profile.extend(conv_dict['user_profile'])

            batch_target.append(conv_dict['target_topic'])

        batch_context = padded_tensor(batch_context,
                                      pad_idx=self.pad_token_idx,
                                      pad_tail=True,
                                      max_len=self.context_truncate)
        batch_context_mask = (batch_context != self.pad_token_idx).long()
        batch_context_policy = padded_tensor(batch_context_policy,
                                             pad_idx=self.pad_token_idx,
                                             pad_tail=True)
        batch_context_policy_mask = (batch_context_policy != self.pad_token_idx).long()
        batch_user_profile = padded_tensor(batch_user_profile,
                                           pad_idx=self.pad_token_idx,
                                           pad_tail=True)
        batch_user_profile_mask = (batch_user_profile != self.pad_token_idx).long()
        batch_target = torch.tensor(batch_target, dtype=torch.long)

        return (batch_context, batch_context_mask, batch_context_policy,
                batch_context_policy_mask, batch_user_profile,
                batch_user_profile_mask, batch_target)
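add_start_end_token_idx is used here with only a start token (the CLS id) and elsewhere with both ends. A minimal sketch consistent with those calls, illustrative rather than the toolkit's code:

    def add_start_end_token_idx(vec, start_token_idx=None, end_token_idx=None):
        """Optionally prepend a start token and append an end token."""
        res = list(vec)
        if start_token_idx is not None:
            res.insert(0, start_token_idx)
        if end_token_idx is not None:
            res.append(end_token_idx)
        return res

    # e.g. with cls_id=101 and sep_id=102:
    # add_start_end_token_idx([7, 8], 101, 102) -> [101, 7, 8, 102]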
Example #12
    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_enhanced_context_tokens = []
        batch_response = []
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            context_tokens = [
                utter + [self.conv_bos_id]
                for utter in conv_dict['context_tokens']
            ]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens),
                         max_length=self.context_truncate,
                         truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(
                    conv_dict['response'],
                    max_length=self.response_truncate - 2),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

            enhanced_topic = []
            if 'target' in conv_dict:
                for target_policy in conv_dict['target']:
                    topic_variable = target_policy[1]
                    if isinstance(topic_variable, list):
                        for topic in topic_variable:
                            enhanced_topic.append(topic)
                enhanced_topic = [[
                    self.tok2ind.get(token, self.unk_token_idx)
                    for token in self.ind2topic[topic_id]
                ] for topic_id in enhanced_topic]
                enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx,
                                           False, self.sent_split_idx)

            enhanced_movie = []
            if 'items' in conv_dict:
                # look up the entity name character by character, dropping any
                # trailing '(...)' qualifier in the name
                enhanced_movie = [[
                    self.tok2ind.get(token, self.unk_token_idx)
                    for token in self.id2entity[movie_id].split('(')[0]
                ] for movie_id in conv_dict['items']]
                enhanced_movie = truncate(merge_utt(enhanced_movie,
                                                    self.word_split_idx, False,
                                                    self.sent_split_idx),
                                          self.item_truncate,
                                          truncate_tail=False)

            if len(enhanced_movie) != 0:
                enhanced_context_tokens = enhanced_movie + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_movie),
                    truncate_tail=False)
            elif len(enhanced_topic) != 0:
                enhanced_context_tokens = enhanced_topic + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_topic),
                    truncate_tail=False)
            else:
                enhanced_context_tokens = batch_context_tokens[-1]
            batch_enhanced_context_tokens.append(enhanced_context_tokens)

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response),
                                    dim=1)
        batch_enhanced_context_tokens = padded_tensor(
            items=batch_enhanced_context_tokens,
            pad_idx=self.pad_token_idx,
            max_len=self.context_truncate,
            pad_tail=False)
        batch_enhanced_input_ids = torch.cat(
            (batch_enhanced_context_tokens, batch_response), dim=1)

        return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
                batch_input_ids, batch_context_tokens,
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False), batch_response)
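The length budget in the enhanced-context branches is the detail to notice: after prepending k enhancement tokens, the context is re-truncated to context_truncate - k, so the combined sequence never exceeds the original budget. A toy check of the arithmetic:

    context_truncate = 8
    enhanced = [901, 902, 903]                            # k = 3
    context = list(range(20, 30))                         # 10 tokens, too long
    keep = context[-(context_truncate - len(enhanced)):]  # most recent 5
    combined = enhanced + keep
    assert len(combined) <= context_truncate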