def conv_interact(self, data):
    # Append a BOS separator after every utterance, then drop the one
    # trailing the last utterance.
    context_tokens = [
        utter + [self.conv_bos_id] for utter in data['context_tokens']
    ]
    context_tokens[-1] = context_tokens[-1][:-1]
    context_tokens = [
        truncate(merge_utt(context_tokens),
                 max_length=self.context_truncate,
                 truncate_tail=False)
    ]
    context_tokens = padded_tensor(items=context_tokens,
                                   pad_idx=self.pad_token_idx,
                                   max_len=self.context_truncate,
                                   pad_tail=False)
    context_entities = [
        truncate(data['context_entities'],
                 self.entity_truncate,
                 truncate_tail=False)
    ]
    context_words = [
        truncate(data['context_words'], self.word_truncate, truncate_tail=False)
    ]

    # At interaction time there is no enhancement and no gold response, so
    # the four token slots all carry the same context and the last slot is
    # None (cf. the return signature of conv_batchify further below).
    return (context_tokens, context_tokens, context_tokens, context_tokens,
            padded_tensor(context_entities, self.pad_entity_idx, pad_tail=False),
            padded_tensor(context_words, self.pad_word_idx, pad_tail=False),
            None)

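# A toy trace of the separator convention above (hypothetical ids): with
# data['context_tokens'] = [[5, 6], [7]] and conv_bos_id = 2, appending a
# BOS after every utterance gives [[5, 6, 2], [7, 2]], stripping the
# trailing BOS gives [[5, 6, 2], [7]], and merge_utt flattens the result
# to [5, 6, 2, 7] before truncation and padding.
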
def rec_interact(self, data):
    context = [self._process_rec_context(data['context_tokens'])]
    if 'interaction_history' in data:
        context_items = data['interaction_history'] + data['context_items']
    else:
        context_items = data['context_items']
    input_ids, input_mask, sample_negs = self._process_history(context_items)
    # Wrap everything in a singleton batch.
    input_ids, input_mask, sample_negs = [input_ids], [input_mask], [sample_negs]

    context = padded_tensor(context, self.pad_token_idx, max_len=self.context_truncate)
    # Mask out padding; compare against the pad index rather than a
    # hard-coded 0.
    mask = (context != self.pad_token_idx).long()

    # target_pos and the gold item id do not exist at interaction time,
    # hence the two None slots.
    return (context, mask,
            padded_tensor(input_ids,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            None,
            padded_tensor(input_mask,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            padded_tensor(sample_negs,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            None)

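# Toy shapes for a single interaction, assuming context_truncate = 4 and
# pad_token_idx = 0 (both hypothetical here): context could be the (1, 4)
# tensor [[101, 7, 8, 0]] and mask its non-pad indicator [[1, 1, 1, 0]].
# The two None slots stand in for target_pos and the gold item id, which
# only exist for training batches (cf. rec_batchify below).
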
def conv_batchify(self, batch):
    max_utterance_length = max(
        max(conversation['utterance_lengths']) for conversation in batch)
    max_response_length = max(
        len(conversation['response']) for conversation in batch)
    max_utterance_length = max(max_utterance_length, max_response_length)
    max_context_length = max(
        conversation['context_length'] for conversation in batch)

    batch_context = []
    batch_context_length = []
    batch_utterance_lengths = []
    batch_request = []
    batch_request_length = []
    batch_response = []

    for conversation in batch:
        # Pad every utterance to the same length, then pad the context to
        # the same number of utterances.
        padded_context = padded_tensor(conversation['context_tokens'],
                                       pad_idx=self.pad_token_idx,
                                       pad_tail=True,
                                       max_len=max_utterance_length)
        if len(conversation['context_tokens']) < max_context_length:
            pad_tensor = padded_context.new_full(
                (max_context_length - len(conversation['context_tokens']),
                 max_utterance_length), self.pad_token_idx)
            padded_context = torch.cat((padded_context, pad_tensor), 0)
        batch_context.append(padded_context)
        batch_context_length.append(conversation['context_length'])
        batch_utterance_lengths.append(
            conversation['utterance_lengths'] +
            [0] * (max_context_length - len(conversation['context_tokens'])))

        request = conversation['request']
        batch_request_length.append(len(request))
        batch_request.append(request)

        # Shallow-copy so the replacement below does not mutate the dataset;
        # tokens matching the movie-id pattern '^\d{5,6}$' are replaced by
        # the generic '__item__' token.
        response = copy(conversation['response'])
        for i in range(len(response)):
            if movie_pattern.match(self.ind2tok[response[i]]):
                response[i] = self.item_token_idx
        batch_response.append(response)

    context = torch.stack(batch_context, dim=0)
    request = padded_tensor(batch_request,
                            self.pad_token_idx,
                            pad_tail=True,
                            max_len=max_utterance_length)
    response = padded_tensor(batch_response,
                             self.pad_token_idx,
                             pad_tail=True,
                             max_len=max_utterance_length)  # (bs, utt_len)

    return {
        'context': context,
        'context_lengths': torch.tensor(batch_context_length),
        'utterance_lengths': torch.tensor(batch_utterance_lengths),
        'request': request,
        'request_lengths': torch.tensor(batch_request_length),
        'response': response
    }

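# The double padding above yields a 3-D context of shape
# (batch_size, max_context_length, max_utterance_length). A minimal
# standalone sketch of the same idea with plain PyTorch (toy data):
#
#     import torch
#     utts = [torch.tensor([5, 6, 7]), torch.tensor([8])]  # one dialog
#     padded = torch.nn.utils.rnn.pad_sequence(utts, batch_first=True)
#     # pad_sequence gives shape (2, 3); all-pad rows are then appended
#     # until every dialog has max_context_length utterances.
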
def conv_batchify(self, batch):
    batch_context_tokens = []
    batch_context_entities = []
    batch_context_words = []
    batch_response = []
    for conv_dict in batch:
        batch_context_tokens.append(
            truncate(merge_utt(conv_dict['context_tokens']),
                     self.context_truncate,
                     truncate_tail=False))
        batch_context_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate,
                     truncate_tail=False))
        batch_response.append(
            add_start_end_token_idx(
                truncate(conv_dict['response'], self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))

    return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
            padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
            padded_tensor(batch_response, self.pad_token_idx))

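# Note that pad_tail=False left-pads, keeping the most recent dialogue
# tokens flush against the right edge of the tensor. Toy example with a
# hypothetical pad_token_idx of 0: padding [5, 6] to length 4 yields
# [0, 0, 5, 6] rather than [5, 6, 0, 0].
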
def conv_batchify(self, batch):
    batch_context_tokens = []
    batch_context_entities = []
    batch_context_words = []
    batch_response = []
    batch_all_movies = []
    has_slot = False  # does any response in this batch contain an item slot?
    for conv_dict in batch:
        response = add_start_end_token_idx(
            truncate(conv_dict['response'], self.response_truncate - 2),
            start_token_idx=self.start_token_idx,
            end_token_idx=self.end_token_idx)
        n_slots = response.count(self.replace_token_idx)
        if n_slots != 0:
            has_slot = True
        batch_context_tokens.append(
            truncate(merge_utt(conv_dict['context_tokens']),
                     self.context_truncate,
                     truncate_tail=False))
        batch_context_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate,
                     truncate_tail=False))
        batch_response.append(response)
        # Keep only the movies that fill the slots, not all entities.
        batch_all_movies.append(
            truncate(conv_dict['items'], n_slots, truncate_tail=False))

    if not has_slot:  # no slot anywhere in the batch, skip it
        return False

    return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
            padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
            padded_tensor(batch_response, self.pad_token_idx),
            padded_tensor(batch_all_movies, self.pad_entity_idx, pad_tail=False))

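# Example of the slot bookkeeping above (hypothetical ids, and assuming
# truncate(vec, n, truncate_tail=False) keeps the last n elements): if
# replace_token_idx = 3 and a processed response is [1, 7, 3, 8, 3, 2],
# then n_slots = 2, so the last two entries of conv_dict['items'] are kept
# as the movies that will fill the two slots, in order.
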
def rec_batchify(self, batch):
    batch_context = []
    batch_movie_id = []
    batch_input_ids = []
    batch_target_pos = []
    batch_input_mask = []
    batch_sample_negs = []
    for conv_dict in batch:
        context = self._process_rec_context(conv_dict['context_tokens'])
        batch_context.append(context)

        item_id = conv_dict['item']
        batch_movie_id.append(item_id)

        if 'interaction_history' in conv_dict:
            context_items = conv_dict['interaction_history'] + conv_dict['context_items']
        else:
            context_items = conv_dict['context_items']

        input_ids, target_pos, input_mask, sample_negs = self._process_history(
            context_items, item_id)
        batch_input_ids.append(input_ids)
        batch_target_pos.append(target_pos)
        batch_input_mask.append(input_mask)
        batch_sample_negs.append(sample_negs)

    batch_context = padded_tensor(batch_context,
                                  self.pad_token_idx,
                                  max_len=self.context_truncate)
    # Mask out padding; compare against the pad index rather than a
    # hard-coded 0.
    batch_mask = (batch_context != self.pad_token_idx).long()

    return (batch_context, batch_mask,
            padded_tensor(batch_input_ids,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            padded_tensor(batch_target_pos,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            padded_tensor(batch_input_mask,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            padded_tensor(batch_sample_negs,
                          pad_idx=self.pad_token_idx,
                          pad_tail=False,
                          max_len=self.item_truncate),
            torch.tensor(batch_movie_id))

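# The sequential inputs above follow the usual next-item-prediction
# layout: for a toy watched history [m1, m2, m3] with gold item m4,
# input_ids would cover the history, target_pos the history shifted by
# one position, and sample_negs one sampled negative item per position;
# the exact construction is delegated to self._process_history.
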
def conv_batchify(self, batch):
    """Batchify conversation data, additionally returning speaker roles."""
    batch_roles = []
    batch_context_tokens = []
    batch_response = []
    for conv_dict in batch:
        batch_roles.append(0 if conv_dict['role'] == 'Seeker' else 1)
        context_tokens = [
            utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']
        ]
        context_tokens[-1] = context_tokens[-1][:-1]
        batch_context_tokens.append(
            truncate(merge_utt(context_tokens),
                     max_length=self.context_truncate,
                     truncate_tail=False))
        batch_response.append(
            add_start_end_token_idx(
                truncate(conv_dict['response'],
                         max_length=self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))

    batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                         pad_idx=self.pad_token_idx,
                                         max_len=self.context_truncate,
                                         pad_tail=False)
    batch_response = padded_tensor(batch_response,
                                   pad_idx=self.pad_token_idx,
                                   max_len=self.response_truncate,
                                   pad_tail=True)
    batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)
    batch_roles = torch.tensor(batch_roles)

    return (batch_roles, batch_input_ids, batch_context_tokens, batch_response)

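# Sketch of the resulting layout (toy sizes): with context_truncate = 6
# and response_truncate = 4, batch_input_ids has width 10 per row, a
# left-padded context followed by a right-padded response, so the model
# can consume the full sequence while the two halves remain separately
# addressable for loss masking.
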
def rec_batchify(self, batch):
    batch_context_entities = []
    batch_context_words = []
    batch_item = []
    for conv_dict in batch:
        batch_context_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate,
                     truncate_tail=False))
        batch_item.append(conv_dict['item'])

    return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
            get_onehot(batch_context_entities, self.n_entity),
            torch.tensor(batch_item, dtype=torch.long))

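# get_onehot turns each entity-id list into a multi-hot vector of length
# n_entity; e.g. with a toy n_entity = 5, the entities [1, 3] become
# [0, 1, 0, 1, 0], a bag-of-entities view used alongside the padded
# sequential views.
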
def pretrain_batchify(self, batch):
    batch_context_entities = []
    batch_context_words = []
    for conv_dict in batch:
        batch_context_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate,
                     truncate_tail=False))

    return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
            get_onehot(batch_context_entities, self.n_entity))

def rec_batchify(self, batch):
    batch_context = []
    batch_movie_id = []
    for conv_dict in batch:
        context = self._process_rec_context(conv_dict['context_tokens'])
        batch_context.append(context)
        item_id = conv_dict['item']
        batch_movie_id.append(item_id)

    batch_context = padded_tensor(batch_context,
                                  self.pad_token_idx,
                                  max_len=self.context_truncate)
    batch_mask = (batch_context != self.pad_token_idx).long()

    return (batch_context, batch_mask, torch.tensor(batch_movie_id))

def policy_batchify(self, batch):
    batch_context = []
    batch_context_policy = []
    batch_user_profile = []
    batch_target = []
    for conv_dict in batch:
        # Tokenize the final target topic and join its words.
        final_topic = conv_dict['final']
        final_topic = [[
            self.tok2ind.get(token, self.unk_token_idx)
            for token in self.ind2topic[topic_id]
        ] for topic_id in final_topic[1]]
        final_topic = merge_utt(final_topic, self.word_split_idx, False, self.sep_id)

        # [CLS] + (tail of) context joined by sentence separators + final topic.
        context = conv_dict['context_tokens']
        context = merge_utt(context, self.sent_split_idx, False, self.sep_id)
        context += final_topic
        context = add_start_end_token_idx(
            truncate(context,
                     max_length=self.context_truncate - 1,
                     truncate_tail=False),
            start_token_idx=self.cls_id)
        batch_context.append(context)

        # Flatten the per-turn policies into [topic, topic, ..., topic].
        context_policy = []
        for policies_one_turn in conv_dict['context_policy']:
            if len(policies_one_turn) != 0:
                for policy in policies_one_turn:
                    for topic_id in policy[1]:
                        if topic_id != self.pad_topic_idx:
                            topic_tokens = [
                                self.tok2ind.get(token, self.unk_token_idx)
                                for token in self.ind2topic[topic_id]
                            ]
                            context_policy.append(topic_tokens)
        context_policy = merge_utt(context_policy, self.word_split_idx, False)
        context_policy = add_start_end_token_idx(context_policy,
                                                 start_token_idx=self.cls_id,
                                                 end_token_idx=self.sep_id)
        context_policy += final_topic
        batch_context_policy.append(context_policy)

        batch_user_profile.extend(conv_dict['user_profile'])
        batch_target.append(conv_dict['target_topic'])

    batch_context = padded_tensor(batch_context,
                                  pad_idx=self.pad_token_idx,
                                  pad_tail=True,
                                  max_len=self.context_truncate)
    batch_context_mask = (batch_context != self.pad_token_idx).long()
    batch_context_policy = padded_tensor(batch_context_policy,
                                         pad_idx=self.pad_token_idx,
                                         pad_tail=True)
    batch_context_policy_mask = (batch_context_policy != self.pad_token_idx).long()
    batch_user_profile = padded_tensor(batch_user_profile,
                                       pad_idx=self.pad_token_idx,
                                       pad_tail=True)
    batch_user_profile_mask = (batch_user_profile != self.pad_token_idx).long()
    batch_target = torch.tensor(batch_target, dtype=torch.long)

    return (batch_context, batch_context_mask, batch_context_policy,
            batch_context_policy_mask, batch_user_profile,
            batch_user_profile_mask, batch_target)

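# The two encoder inputs assembled above, schematically (separator names
# are illustrative): batch_context is roughly
#     [CLS] utt_1 [SENT] utt_2 ... [SEP] final_topic_tokens
# truncated from the head so the most recent turns survive, while
# batch_context_policy is the flattened topic history
#     [CLS] topic_1 [WORD] topic_2 ... [SEP] final_topic_tokens
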
def conv_batchify(self, batch):
    batch_context_tokens = []
    batch_enhanced_context_tokens = []
    batch_response = []
    batch_context_entities = []
    batch_context_words = []
    for conv_dict in batch:
        context_tokens = [
            utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']
        ]
        context_tokens[-1] = context_tokens[-1][:-1]
        batch_context_tokens.append(
            truncate(merge_utt(context_tokens),
                     max_length=self.context_truncate,
                     truncate_tail=False))
        batch_response.append(
            add_start_end_token_idx(
                truncate(conv_dict['response'],
                         max_length=self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))
        batch_context_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate,
                     truncate_tail=False))

        # Collect the tokens of the target topics, if any.
        enhanced_topic = []
        if 'target' in conv_dict:
            for target_policy in conv_dict['target']:
                topic_variable = target_policy[1]
                if isinstance(topic_variable, list):
                    for topic in topic_variable:
                        enhanced_topic.append(topic)
            enhanced_topic = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.ind2topic[topic_id]
            ] for topic_id in enhanced_topic]
            enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx,
                                       False, self.sent_split_idx)

        # Collect the tokens of the target movies, if any; each character of
        # the title (before any '(' suffix) is looked up in the vocabulary.
        enhanced_movie = []
        if 'items' in conv_dict:
            for movie_id in conv_dict['items']:
                enhanced_movie.append(movie_id)
            enhanced_movie = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.id2entity[movie_id].split('(')[0]
            ] for movie_id in enhanced_movie]
            # Join title tokens with word separators and close with a
            # sentence separator, as for enhanced_topic above.
            enhanced_movie = truncate(
                merge_utt(enhanced_movie, self.word_split_idx, False,
                          self.sent_split_idx),
                self.item_truncate,
                truncate_tail=False)

        # Prepend the guidance tokens (movies preferred over topics) to the
        # plain context, shrinking the context so the total length fits.
        if len(enhanced_movie) != 0:
            enhanced_context_tokens = enhanced_movie + truncate(
                batch_context_tokens[-1],
                max_length=self.context_truncate - len(enhanced_movie),
                truncate_tail=False)
        elif len(enhanced_topic) != 0:
            enhanced_context_tokens = enhanced_topic + truncate(
                batch_context_tokens[-1],
                max_length=self.context_truncate - len(enhanced_topic),
                truncate_tail=False)
        else:
            enhanced_context_tokens = batch_context_tokens[-1]
        batch_enhanced_context_tokens.append(enhanced_context_tokens)

    batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                         pad_idx=self.pad_token_idx,
                                         max_len=self.context_truncate,
                                         pad_tail=False)
    batch_response = padded_tensor(batch_response,
                                   pad_idx=self.pad_token_idx,
                                   max_len=self.response_truncate,
                                   pad_tail=True)
    batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)
    batch_enhanced_context_tokens = padded_tensor(
        items=batch_enhanced_context_tokens,
        pad_idx=self.pad_token_idx,
        max_len=self.context_truncate,
        pad_tail=False)
    batch_enhanced_input_ids = torch.cat(
        (batch_enhanced_context_tokens, batch_response), dim=1)

    return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
            batch_input_ids, batch_context_tokens,
            padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
            batch_response)

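# In short, the "enhanced" variant prepends guidance tokens to the plain
# context: movie-title tokens when items are available, otherwise
# target-topic tokens, otherwise nothing. Both variants are padded to
# context_truncate and concatenated with the same padded response, so the
# model receives aligned (enhanced, plain) input pairs of equal width.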