def collate(data_list):
    """Collate a list of per-sample dicts into one batch, field by field.

    Tensor fields are assembled explicitly; raw-text fields stay as plain
    Python lists. The batch is moved to GPU when ``device >= 0``.
    """
    batch = Pack()
    # Fields converted to tensors, handled one key at a time.
    for field in ('num_src', 'num_tgt_input', 'tgt_output', 'tgt_emo', 'mask'):
        batch[field] = list2tensor([sample[field] for sample in data_list])
    # Raw strings are carried through untensorised.
    batch['raw_src'] = [sample['raw_src'] for sample in data_list]
    batch['raw_tgt'] = [sample['raw_tgt'] for sample in data_list]
    if 'id' in data_list[0]:
        batch['id'] = [sample['id'] for sample in data_list]
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Merge a list of sample dicts into a tensorised batch Pack.

    Every key of the first sample is gathered across the whole list and
    converted with ``list2tensor``; the batch is moved to GPU when
    ``device >= 0``.
    """
    first = data_list[0]
    batch = Pack()
    for field in first.keys():
        column = [sample[field] for sample in data_list]
        batch[field] = list2tensor(column)
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def create_situation_batch(situation_list):
    """Tensorise situations after stripping each sequence's first and last
    element (presumably BOS/EOS markers — confirm against the data pipeline)."""
    trimmed = [situation[1:-1] for situation in situation_list]
    return list2tensor(trimmed)
def create_user_profile_batch(user_profile_list):
    """Tensorise user profiles; each inner sequence loses its first and last
    element (presumably BOS/EOS markers — confirm against the data pipeline)."""
    trimmed = [[entry[1:-1] for entry in profile]
               for profile in user_profile_list]
    return list2tensor(trimmed)
def collate(data_list):
    """Collate paired samples into two batches.

    Each element of ``data_list`` is a pair of dicts; the first halves are
    collated into ``batch1`` and the second halves into ``batch2``. Each
    batch is moved to GPU when ``device >= 0``.

    Fix: the original converted the ``zip`` output to ``list`` twice
    (``data_list2 = list(data_list2)`` followed by ``list(data_list2)[0]``);
    tuples from ``zip`` already support the indexing and iteration used here,
    so the redundant conversions are dropped.
    """
    data_list1, data_list2 = zip(*data_list)
    batch1 = Pack()
    batch2 = Pack()
    for key in data_list1[0].keys():
        batch1[key] = list2tensor([x[key] for x in data_list1])
    if device >= 0:
        batch1 = batch1.cuda(device=device)
    for key in data_list2[0].keys():
        batch2[key] = list2tensor([x[key] for x in data_list2])
    if device >= 0:
        batch2 = batch2.cuda(device=device)
    return batch1, batch2
def interact(self, src, cue=None):
    """Run one interactive inference step.

    Numericalizes ``src`` (and ``cue`` when given), runs ``forward`` with a
    single candidate, and returns the denumericalized top prediction.
    Returns ``None`` for an empty source string.
    """
    if src == "":
        return None
    inputs = Pack()
    src_ids = self.src_field.numericalize([src])
    inputs.add(src=list2tensor(src_ids))
    if cue is not None:
        cue_ids = self.cue_field.numericalize([cue])
        inputs.add(cue=list2tensor(cue_ids))
    if self.use_gpu:
        inputs = inputs.cuda()
    _, preds, _, _ = self.forward(inputs=inputs, num_candidates=1)
    return self.tgt_field.denumericalize(preds[0][0])
def collate(data_list):
    """Collate one batch.

    ``data_list`` has batch_size elements, each produced by ``__getitem__``.
    All values sharing a key (e.g. src, tgt, cue) are gathered across the
    batch and converted into a single tensor with ``list2tensor``. The batch
    is moved to GPU when ``device >= 0``.
    """
    batch = Pack()
    for field in data_list[0].keys():
        batch[field] = list2tensor([item[field] for item in data_list])
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Merge per-sample dicts into a batch Pack.

    ``data_list`` is a list of dicts — one dict per sample. Each field is
    stacked across samples via ``list2tensor``; the result is moved to GPU
    when ``device >= 0``.
    """
    batch = Pack()
    keys = data_list[0].keys()
    for key in keys:
        column = [sample[key] for sample in data_list]
        batch[key] = list2tensor(column)
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def create_turn_batch(data_list):
    """Build one Pack per turn.

    For tensor-able fields ('src', 'tgt', 'ptr_index', 'kb_index') the value
    is converted with ``list2tensor``; every other field is copied through
    unchanged.

    Fixes: the no-op copy comprehension ``[x for x in data_dict[key]]`` is
    replaced with ``list(...)`` (ruff PERF402), and the membership test uses
    a set instead of scanning a list on every key.
    """
    tensor_keys = {'src', 'tgt', 'ptr_index', 'kb_index'}
    turn_batches = []
    for data_dict in data_list:
        batch = Pack()
        for key in data_dict.keys():
            if key in tensor_keys:
                batch[key] = list2tensor(list(data_dict[key]))
            else:
                batch[key] = data_dict[key]
        turn_batches.append(batch)
    return turn_batches
def collate(data_list):
    """Collate samples into a batch; the 'topic' field becomes a dense
    bag-of-words tensor under the 'bow' key instead of being tensorised
    directly.
    """
    batch = Pack()
    for key in data_list[0].keys():
        if key == 'topic':
            continue
        batch[key] = list2tensor([sample[key] for sample in data_list])
    # Build one dense bag-of-words row per sample by accumulating its
    # (word_id, frequency) pairs — assumes 'topic' yields 2-tuples, not a
    # plain dict (iterating a dict would yield keys only); confirm upstream.
    bow_rows = []
    for sample in data_list:
        row = torch.zeros(bow_vocab_size, dtype=torch.float)
        for word_id, freq in sample['topic']:
            row[word_id] += freq
        bow_rows.append(row)
    batch['bow'] = torch.stack(bow_rows)
    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
def collate(data_list):
    """Collate a batch and attach copy-mechanism vocabulary mappings.

    data_list: List[Dict]. Standard fields are tensorised with
    ``list2tensor``; then the whitespace-tokenised raw source strings are
    used to build the extended-vocabulary mappings for the copy mechanism.
    """
    batch = Pack()
    for key in data_list[0].keys():
        batch[key] = list2tensor([sample[key] for sample in data_list])
    if device >= 0:
        batch = batch.cuda(device=device)
    # Copy-mechanism preparation: per-batch mappings between tokens and
    # extended-vocab indices, derived from the raw source text.
    tokenised_src = [sample['raw_src'].split() for sample in data_list]
    (token2idx, idx2token,
     batch_pos_idx_map, idx2idx_mapping) = build_copy_mapping(tokenised_src,
                                                              vocab)
    batch['token2idx'] = token2idx
    batch['idx2token'] = idx2token
    batch['batch_pos_idx_map'] = batch_pos_idx_map
    batch['idx2idx_mapping'] = idx2idx_mapping
    # NOTE(review): '???' looks like an unfinished placeholder — confirm the
    # intended value before relying on batch['output'] downstream.
    batch['output'] = '???'
    return batch
def create_kb_batch(kb_list):
    """Convert a list of knowledge-base entries directly into a tensor batch."""
    return list2tensor(kb_list)