def collate(data_list):
    """
    Assemble a list of per-sample dicts into one Pack batch.

    Fields converted with `list2tensor`: num_src, num_tgt_input,
    tgt_output, tgt_emo and mask. Fields kept as plain Python lists:
    raw_src, raw_tgt, and id (the latter only when the first sample has
    an 'id' key). The batch is moved to GPU `device` when `device >= 0`
    (`device` is resolved from the enclosing scope).
    """
    def gather(field):
        # Collect one field across all samples, preserving sample order.
        return [sample[field] for sample in data_list]

    batch = Pack()
    tensor_fields = ('num_src', 'num_tgt_input', 'tgt_output',
                     'tgt_emo', 'mask')
    for field in tensor_fields:
        batch[field] = list2tensor(gather(field))
    for field in ('raw_src', 'raw_tgt'):
        batch[field] = gather(field)
    if 'id' in data_list[0]:
        batch['id'] = gather('id')

    if device >= 0:
        batch = batch.cuda(device=device)
    return batch
Beispiel #2
0
 def collate(data_list):
     """
     Stack every field of the samples into a single Pack.

     Keys are taken from the first sample; for each key the values of all
     samples are combined with `list2tensor`. The batch is moved to GPU
     `device` when `device >= 0` (`device` comes from the enclosing scope).
     """
     first = data_list[0]
     batch = Pack()
     for key in first:
         column = [sample[key] for sample in data_list]
         batch[key] = list2tensor(column)
     return batch.cuda(device=device) if device >= 0 else batch
Beispiel #3
0
def create_situation_batch(situation_list):
    """
    Drop the first and last element of every situation, then batch them.

    NOTE(review): the trimmed elements are presumably boundary tokens
    (e.g. BOS/EOS) — confirm against the data pipeline.

    Returns:
        The result of `list2tensor` on the trimmed sequences.
    """
    trimmed = [situation[1:-1] for situation in situation_list]
    return list2tensor(trimmed)
Beispiel #4
0
def create_user_profile_batch(user_profile_list):
    """
    Trim every entry of every user profile, then batch the profiles.

    Each profile is an iterable of sequences; the first and last element of
    each sequence are removed before the nested lists are passed to
    `list2tensor`.
    """
    trimmed = []
    for profile in user_profile_list:
        trimmed.append([entry[1:-1] for entry in profile])
    return list2tensor(trimmed)
Beispiel #5
0
 def collate(data_list):
     """
     Collate a list of (sample1, sample2) pairs into two Pack batches.

     Each element of `data_list` is a pair of dicts. The first members and
     the second members are stacked independently, field by field, with
     `list2tensor`; keys are taken from the first sample of each side.
     Both batches are moved to GPU `device` when `device >= 0` (`device`
     comes from the enclosing scope).

     Returns:
         (batch1, batch2): two Packs.
     """
     def _stack(samples):
         # Build one Pack from the tensorised columns of `samples`.
         batch = Pack()
         for key in samples[0]:
             batch[key] = list2tensor([x[key] for x in samples])
         if device >= 0:
             batch = batch.cuda(device=device)
         return batch

     data_list1, data_list2 = zip(*data_list)
     return _stack(data_list1), _stack(data_list2)
Beispiel #6
0
    def interact(self, src, cue=None):
        """
        Generate one response for a single raw source input.

        Args:
            src: raw source input, numericalised with `self.src_field`.
                An empty string short-circuits and returns None.
            cue: optional raw cue input, numericalised with
                `self.cue_field`; omitted from the inputs when None.

        Returns:
            `self.tgt_field.denumericalize` applied to the first candidate of
            the first prediction, or None when `src` is "".
        """
        if src == "":
            return None

        inputs = Pack()
        inputs.add(src=list2tensor(self.src_field.numericalize([src])))
        if cue is not None:
            inputs.add(cue=list2tensor(self.cue_field.numericalize([cue])))
        if self.use_gpu:
            inputs = inputs.cuda()

        # forward yields a 4-tuple; only the second element (predictions)
        # is needed here.
        _, preds, _, _ = self.forward(inputs=inputs, num_candidates=1)
        best = preds[0][0]
        return self.tgt_field.denumericalize(best)
 def collate(data_list):
     """
     Merge one batch of samples into a Pack.

     `data_list` holds one dict per sample in the batch (each produced by the
     dataset's `__getitem__`); its length is the batch size. For every key of
     the first sample (src, tgt, cue per the original author's note) the
     values of all samples are combined with `list2tensor`. The batch is
     moved to GPU `device` when `device >= 0`.
     """
     fields = data_list[0].keys()
     batch = Pack()
     for name in fields:
         # Gather this field from every sample in the batch.
         batch[name] = list2tensor([sample[name] for sample in data_list])
     if device >= 0:
         batch = batch.cuda(device=device)
     return batch
 def collate(data_list):
     """
     Turn a list of per-sample dicts into a single batched Pack.

     Each sample is one dict; for every key of the first sample, the values
     across all samples are passed to `list2tensor`. The batch is moved to
     GPU `device` unless `device` is negative.
     """
     batch = Pack()
     keys = list(data_list[0])
     for key in keys:
         values = [sample[key] for sample in data_list]
         batch[key] = list2tensor(values)
     if device < 0:
         return batch
     return batch.cuda(device=device)
Beispiel #9
0
def create_turn_batch(data_list):
    """
    Convert each turn dict in `data_list` into its own Pack.

    The fields 'src', 'tgt', 'ptr_index' and 'kb_index' are converted with
    `list2tensor`; every other field is copied through unchanged.

    Returns:
        A list of Packs, one per input dict, in input order.
    """
    tensor_keys = {'src', 'tgt', 'ptr_index', 'kb_index'}
    turn_batches = []
    for data_dict in data_list:
        batch = Pack()
        for key, value in data_dict.items():
            if key in tensor_keys:
                # Copy into a fresh list, as the original comprehension did.
                batch[key] = list2tensor(list(value))
            else:
                batch[key] = value
        turn_batches.append(batch)
    return turn_batches
Beispiel #10
0
        def collate(data_list):
            """
            Collate samples into a Pack and add a bag-of-words tensor.

            Every field except 'topic' is stacked with `list2tensor`. Each
            sample's 'topic' is turned into a dense float vector of length
            `bow_vocab_size` (from the enclosing scope). For every
            (word_index, weight) entry, the weight is added at that index, so
            repeated indices accumulate. 'topic' may be an iterable of
            (word_index, weight) pairs or a dict mapping word_index -> weight.
            The stacked vectors are stored under 'bow' with shape
            (batch, bow_vocab_size).

            The batch is moved to GPU `device` when `device >= 0`.
            """
            batch = Pack()
            for key in data_list[0]:
                if key == 'topic':
                    continue
                batch[key] = list2tensor([x[key] for x in data_list])

            batch_bow = []
            for x in data_list:
                v = torch.zeros(bow_vocab_size, dtype=torch.float)
                x_bow = x['topic']
                # Iterating a plain dict yields only its keys, which cannot
                # be unpacked into (index, weight); use .items() for dicts and
                # iterate pair sequences directly.
                entries = x_bow.items() if isinstance(x_bow, dict) else x_bow
                for w, f in entries:
                    v[w] += f
                batch_bow.append(v)
            batch['bow'] = torch.stack(batch_bow)

            if device >= 0:
                batch = batch.cuda(device=device)
            return batch
Beispiel #11
0
        def collate(data_list):
            """
            Collate samples into a Pack and attach copy-mechanism mappings.

            Args:
                data_list: List[Dict], one dict per sample; each 'raw_src'
                    value is split on whitespace with `.split()`.

            Every field of the first sample is stacked with `list2tensor`,
            and the batch is moved to GPU `device` when `device >= 0`. Then
            the tokenised 'raw_src' values are passed to `build_copy_mapping`
            together with `vocab` (both from the enclosing scope). Its four
            results are stored as 'token2idx', 'idx2token',
            'batch_pos_idx_map' and 'idx2idx_mapping'. They are added after
            the device move, so `batch.cuda` is not applied to them.
            """
            batch = Pack()
            for key in data_list[0].keys():
                batch[key] = list2tensor([sample[key] for sample in data_list])
            if device >= 0:
                batch = batch.cuda(device=device)

            # Copy mechanism: tokenise each raw source on whitespace.
            tokenised_src = [sample['raw_src'].split() for sample in data_list]
            mapping = build_copy_mapping(tokenised_src, vocab)
            (batch['token2idx'], batch['idx2token'],
             batch['batch_pos_idx_map'], batch['idx2idx_mapping']) = mapping
            # NOTE(review): '???' looks like a leftover placeholder; kept
            # because downstream code may read this key — confirm and remove.
            batch['output'] = '???'

            return batch
Beispiel #12
0
def create_kb_batch(kb_list):
    """
    Convert `kb_list` into a batch with `list2tensor` and return it.
    """
    return list2tensor(kb_list)