def conv_interact(self, data):
    """Build model inputs for a single interactive (inference-time) example.

    Joins the context utterances with BOS separators, truncates/pads them,
    and wraps entities and words into batch-of-one tensors.
    """
    # Insert a BOS marker after every utterance, then drop the trailing one.
    context_tokens = [utter + [self.conv_bos_id] for utter in data['context_tokens']]
    context_tokens[-1] = context_tokens[-1][:-1]
    flat_context = truncate(merge_utt(context_tokens),
                            max_length=self.context_truncate,
                            truncate_tail=False)
    context_tokens = padded_tensor(items=[flat_context],
                                   pad_idx=self.pad_token_idx,
                                   max_len=self.context_truncate,
                                   pad_tail=False)
    context_entities = [truncate(data['context_entities'],
                                 self.entity_truncate, truncate_tail=False)]
    context_words = [truncate(data['context_words'],
                              self.word_truncate, truncate_tail=False)]
    # The token tensor is repeated four times — presumably to mirror the
    # (enhanced_input_ids, enhanced_context, input_ids, context) layout of
    # the training batchify; confirm against the model's forward signature.
    return (
        context_tokens,
        context_tokens,
        context_tokens,
        context_tokens,
        padded_tensor(context_entities, self.pad_entity_idx, pad_tail=False),
        padded_tensor(context_words, self.pad_word_idx, pad_tail=False),
        None,
    )
def conv_batchify(self, batch):
    """Collate conversation samples into padded tensors.

    Returns a 4-tuple of padded tensors: context tokens, context entities,
    context words, and start/end-wrapped responses.
    """
    all_tokens, all_entities, all_words, all_responses = [], [], [], []
    for conv_dict in batch:
        all_tokens.append(
            truncate(merge_utt(conv_dict['context_tokens']),
                     self.context_truncate, truncate_tail=False))
        all_entities.append(
            truncate(conv_dict['context_entities'],
                     self.entity_truncate, truncate_tail=False))
        all_words.append(
            truncate(conv_dict['context_words'],
                     self.word_truncate, truncate_tail=False))
        # Reserve two positions for the start and end tokens.
        all_responses.append(
            add_start_end_token_idx(
                truncate(conv_dict['response'], self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))
    return (
        padded_tensor(all_tokens, self.pad_token_idx, pad_tail=False),
        padded_tensor(all_entities, self.pad_entity_idx, pad_tail=False),
        padded_tensor(all_words, self.pad_word_idx, pad_tail=False),
        padded_tensor(all_responses, self.pad_token_idx),
    )
def conv_batchify(self, batch):
    """Collate conversation samples that may contain item-replacement slots.

    For each sample, counts occurrences of ``self.replace_token_idx`` in the
    processed response and keeps that many items (movies) for slot filling.

    Returns:
        ``False`` if no response in the batch contains a replacement slot;
        otherwise a 5-tuple of padded tensors (context tokens, context
        entities, context words, responses, movies).
    """
    batch_context_tokens = []
    batch_context_entities = []
    batch_context_words = []
    batch_response = []
    batch_all_movies = []
    has_slot = False
    for conv_dict in batch:
        # Build the response once and reuse it; the original recomputed
        # add_start_end_token_idx(truncate(...)) a second time for
        # batch_response with identical arguments.
        response = add_start_end_token_idx(
            truncate(conv_dict['response'], self.response_truncate - 2),
            start_token_idx=self.start_token_idx,
            end_token_idx=self.end_token_idx)
        n_slots = response.count(self.replace_token_idx)
        if n_slots != 0:
            has_slot = True
        batch_context_tokens.append(
            truncate(merge_utt(conv_dict['context_tokens']),
                     self.context_truncate, truncate_tail=False))
        batch_context_entities.append(
            truncate(conv_dict['context_entities'], self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'], self.word_truncate,
                     truncate_tail=False))
        batch_response.append(response)
        # Only use movies (items), not all entities: keep one per slot.
        batch_all_movies.append(
            truncate(conv_dict['items'], n_slots, truncate_tail=False))
    if not has_slot:  # zero slots in the whole batch
        return False
    return (padded_tensor(batch_context_tokens, self.pad_token_idx,
                          pad_tail=False),
            padded_tensor(batch_context_entities, self.pad_entity_idx,
                          pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx,
                          pad_tail=False),
            padded_tensor(batch_response, self.pad_token_idx),
            padded_tensor(batch_all_movies, self.pad_entity_idx,
                          pad_tail=False))
def _process_rec_context(self, context_tokens):
    """Flatten context utterances into one start/end-wrapped token sequence.

    Utterances after the first are prefixed with the sentence-split token,
    merged, truncated (reserving 2 slots for start/end), and wrapped.

    Args:
        context_tokens: list of token-id lists, one per utterance.

    Returns:
        A single list of token ids.
    """
    compact_context = []
    for i, utterance in enumerate(context_tokens):
        if i != 0:
            # BUG FIX: the original used `utterance.insert(0, ...)`, which
            # mutates the caller's (dataset's) lists in place — processing
            # the same sample twice accumulated duplicate split tokens.
            # Build a new list instead.
            utterance = [self.sent_split_idx] + utterance
        compact_context.append(utterance)
    compact_context = truncate(merge_utt(compact_context),
                               self.context_truncate - 2,
                               truncate_tail=False)
    return add_start_end_token_idx(compact_context,
                                   self.start_token_idx,
                                   self.end_token_idx)
def conv_batchify(self, batch):
    """Get batch and corresponding roles.

    Returns a 4-tuple: role tensor (0 for 'Seeker', 1 otherwise),
    concatenated input ids (context ++ response), padded context tokens,
    and padded responses.
    """
    roles, contexts, responses = [], [], []
    for conv_dict in batch:
        roles.append(0 if conv_dict['role'] == 'Seeker' else 1)
        # BOS after every utterance; strip the trailing one.
        utterances = [u + [self.conv_bos_id] for u in conv_dict['context_tokens']]
        utterances[-1] = utterances[-1][:-1]
        contexts.append(
            truncate(merge_utt(utterances),
                     max_length=self.context_truncate,
                     truncate_tail=False))
        responses.append(
            add_start_end_token_idx(
                truncate(conv_dict['response'],
                         max_length=self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))
    # Contexts are left-padded, responses right-padded, so the pair can be
    # concatenated along the sequence dimension.
    contexts = padded_tensor(items=contexts,
                             pad_idx=self.pad_token_idx,
                             max_len=self.context_truncate,
                             pad_tail=False)
    responses = padded_tensor(responses,
                              pad_idx=self.pad_token_idx,
                              max_len=self.response_truncate,
                              pad_tail=True)
    input_ids = torch.cat((contexts, responses), dim=1)
    return (torch.tensor(roles), input_ids, contexts, responses)
def policy_batchify(self, batch):
    """Collate policy samples into padded tensors plus attention masks.

    For each sample builds:
      * context: dialogue tokens + final-topic tokens, CLS-prefixed;
      * context_policy: historical guidance topics flattened to token lists,
        CLS/SEP-wrapped, followed by the final-topic tokens;
      * user profile token lists and the target topic id.

    Returns a 7-tuple: (context, context_mask, context_policy,
    context_policy_mask, user_profile, user_profile_mask, target).
    Masks are 1 where the padded tensor is non-zero.
    """
    batch_context = []
    batch_context_policy = []
    batch_user_profile = []
    batch_target = []
    for conv_dict in batch:
        # Final topic ids -> token ids, joined with word-split / SEP markers.
        final_topic = conv_dict['final']
        final_topic = [[
            self.tok2ind.get(token, self.unk_token_idx)
            for token in self.ind2topic[topic_id]
        ] for topic_id in final_topic[1]]
        final_topic = merge_utt(final_topic, self.word_split_idx, False,
                                self.sep_id)

        # Dialogue context followed by the final topic; reserve one slot
        # for the CLS token added by add_start_end_token_idx.
        context = conv_dict['context_tokens']
        context = merge_utt(context, self.sent_split_idx, False, self.sep_id)
        context += final_topic
        context = add_start_end_token_idx(
            truncate(context, max_length=self.context_truncate - 1,
                     truncate_tail=False),
            start_token_idx=self.cls_id)
        batch_context.append(context)

        # Historical policies: [topic, topic, ..., topic] as token lists.
        # NOTE: the original rebound the loop variable `policy` inside its
        # own inner loop; renamed to `topic_tokens` for clarity — behavior
        # is unchanged.
        context_policy = []
        for policies_one_turn in conv_dict['context_policy']:
            if len(policies_one_turn) != 0:
                for policy in policies_one_turn:
                    for topic_id in policy[1]:
                        if topic_id != self.pad_topic_idx:
                            topic_tokens = [
                                self.tok2ind.get(token, self.unk_token_idx)
                                for token in self.ind2topic[topic_id]
                            ]
                            context_policy.append(topic_tokens)
        context_policy = merge_utt(context_policy, self.word_split_idx, False)
        context_policy = add_start_end_token_idx(
            context_policy, start_token_idx=self.cls_id,
            end_token_idx=self.sep_id)
        context_policy += final_topic
        batch_context_policy.append(context_policy)

        batch_user_profile.extend(conv_dict['user_profile'])
        batch_target.append(conv_dict['target_topic'])

    batch_context = padded_tensor(batch_context,
                                  pad_idx=self.pad_token_idx,
                                  pad_tail=True,
                                  max_len=self.context_truncate)
    # Renamed from the original misspelling `batch_cotnext_mask`
    # (local variable only; the returned tuple is unchanged).
    batch_context_mask = (batch_context != 0).long()
    batch_context_policy = padded_tensor(batch_context_policy,
                                         pad_idx=self.pad_token_idx,
                                         pad_tail=True)
    batch_context_policy_mask = (batch_context_policy != 0).long()
    batch_user_profile = padded_tensor(batch_user_profile,
                                       pad_idx=self.pad_token_idx,
                                       pad_tail=True)
    batch_user_profile_mask = (batch_user_profile != 0).long()
    batch_target = torch.tensor(batch_target, dtype=torch.long)
    return (batch_context, batch_context_mask, batch_context_policy,
            batch_context_policy_mask, batch_user_profile,
            batch_user_profile_mask, batch_target)
def conv_batchify(self, batch):
    """Collate conversation samples, additionally building an "enhanced"
    context that prepends target-topic or recommended-item tokens to the
    plain dialogue context.

    Returns a 7-tuple: (enhanced_input_ids, enhanced_context_tokens,
    input_ids, context_tokens, context_entities, context_words, response),
    where input_ids = context ++ response along the sequence dimension.
    """
    batch_context_tokens = []
    batch_enhanced_context_tokens = []
    batch_response = []
    batch_context_entities = []
    batch_context_words = []
    for conv_dict in batch:
        # Insert a BOS marker after every utterance, then drop the trailing one.
        context_tokens = [
            utter + [self.conv_bos_id]
            for utter in conv_dict['context_tokens']
        ]
        context_tokens[-1] = context_tokens[-1][:-1]
        batch_context_tokens.append(
            truncate(merge_utt(context_tokens),
                     max_length=self.context_truncate,
                     truncate_tail=False),
        )
        # Reserve two positions for the start/end tokens.
        batch_response.append(
            add_start_end_token_idx(truncate(
                conv_dict['response'],
                max_length=self.response_truncate - 2),
                start_token_idx=self.start_token_idx,
                end_token_idx=self.end_token_idx))
        batch_context_entities.append(
            truncate(conv_dict['context_entities'], self.entity_truncate,
                     truncate_tail=False))
        batch_context_words.append(
            truncate(conv_dict['context_words'], self.word_truncate,
                     truncate_tail=False))
        # Target-policy topics -> token ids (only when 'target' is present).
        enhanced_topic = []
        if 'target' in conv_dict:
            for target_policy in conv_dict['target']:
                topic_variable = target_policy[1]
                if isinstance(topic_variable, list):
                    for topic in topic_variable:
                        enhanced_topic.append(topic)
            enhanced_topic = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.ind2topic[topic_id]
            ] for topic_id in enhanced_topic]
            enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx,
                                       False, self.sent_split_idx)
        # Recommended items -> token ids; the entity name is tokenized
        # character-by-character (iterating the string before '(').
        enhanced_movie = []
        if 'items' in conv_dict:
            for movie_id in conv_dict['items']:
                enhanced_movie.append(movie_id)
            enhanced_movie = [[
                self.tok2ind.get(token, self.unk_token_idx)
                for token in self.id2entity[movie_id].split('(')[0]
            ] for movie_id in enhanced_movie]
            # NOTE(review): this merge_utt call passes 3 positionals while
            # the topic branch above passes 4 (an explicit False before
            # self.sent_split_idx). It looks like sent_split_idx lands in
            # a different parameter slot here — verify against merge_utt's
            # signature; possibly a missing `False` argument.
            enhanced_movie = truncate(merge_utt(enhanced_movie,
                                                self.word_split_idx,
                                                self.sent_split_idx),
                                      self.item_truncate,
                                      truncate_tail=False)
        # Prefer item tokens over topic tokens when building the enhanced
        # context; shrink the context so the total stays within
        # context_truncate.
        if len(enhanced_movie) != 0:
            enhanced_context_tokens = enhanced_movie + truncate(
                batch_context_tokens[-1],
                max_length=self.context_truncate - len(enhanced_movie),
                truncate_tail=False)
        elif len(enhanced_topic) != 0:
            enhanced_context_tokens = enhanced_topic + truncate(
                batch_context_tokens[-1],
                max_length=self.context_truncate - len(enhanced_topic),
                truncate_tail=False)
        else:
            enhanced_context_tokens = batch_context_tokens[-1]
        batch_enhanced_context_tokens.append(enhanced_context_tokens)
    # Contexts are left-padded and responses right-padded so that the two
    # can be concatenated along dim=1.
    batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                         pad_idx=self.pad_token_idx,
                                         max_len=self.context_truncate,
                                         pad_tail=False)
    batch_response = padded_tensor(batch_response,
                                   pad_idx=self.pad_token_idx,
                                   max_len=self.response_truncate,
                                   pad_tail=True)
    batch_input_ids = torch.cat((batch_context_tokens, batch_response),
                                dim=1)
    batch_enhanced_context_tokens = padded_tensor(
        items=batch_enhanced_context_tokens,
        pad_idx=self.pad_token_idx,
        max_len=self.context_truncate,
        pad_tail=False)
    batch_enhanced_input_ids = torch.cat(
        (batch_enhanced_context_tokens, batch_response), dim=1)
    return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
            batch_input_ids, batch_context_tokens,
            padded_tensor(batch_context_entities, self.pad_entity_idx,
                          pad_tail=False),
            padded_tensor(batch_context_words, self.pad_word_idx,
                          pad_tail=False), batch_response)