Example #1
File: amr.py Project: lei1993/HanLP
# Snippet from HanLP; requires torch, and assumes these HanLP-internal
# helpers are importable in this module: TransformerSequenceTokenizer,
# merge_list_of_dict, SelfAttentionMask, subtoken_to_tensor.
import torch


def make_batch_for_bart(augmented_concept,
                        ret,
                        tokenizer,
                        device,
                        training=True):
    token_field = 'concept'
    # Re-wrap the raw subword tokenizer so each concept sequence is encoded
    # with a BOS ([CLS]) slot prepended.
    tokenizer = TransformerSequenceTokenizer(tokenizer.tokenizer,
                                             token_field,
                                             cls_is_bos=True,
                                             sep_is_eos=None)
    # Teacher forcing: during training the decoder input drops the last
    # concept, which becomes the prediction target.
    encodings = [
        tokenizer({token_field: x[:-1] if training else x})
        for x in augmented_concept
    ]
    ret.update(merge_list_of_dict(encodings))
    decoder_mask = []
    max_seq_len = len(max(ret['concept_input_ids'], key=len))
    last_concept_offset = []
    for spans, concepts in zip(ret['concept_token_span'], augmented_concept):
        # Start from a causal mask, then additionally let subwords within
        # the same concept span attend to one another.
        mask = ~SelfAttentionMask.get_mask(
            max_seq_len, device, ret_parameter=False)
        for group in spans:
            for i in range(len(group)):
                for j in range(i + 1, len(group)):
                    mask[group[i], group[j]] = True
        decoder_mask.append(mask)
        last_concept_offset.append(len(concepts) - 1)
    ret['decoder_mask'] = torch.stack(decoder_mask)
    if not training:
        # At inference, remember where the last concept of each sample sits.
        ret['last_concept_offset'] = torch.tensor(last_concept_offset,
                                                  device=device,
                                                  dtype=torch.long)
    # Pad the subtoken fields in `ret` into tensors (HanLP-internal helper).
    subtoken_to_tensor(token_field, ret)
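The decoder mask built above overlays intra-span visibility on top of a causal mask. Below is a minimal torch-only sketch of the same idea; SelfAttentionMask is HanLP-internal, so torch.tril stands in for its negated output, and the span group is hypothetical.

import torch

max_seq_len = 5
# Causal mask: position i may attend to positions j <= i.
mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
# Hypothetical span group: subwords 2..4 form one concept, so earlier
# subwords in the group may also attend forward within it.
group = [2, 3, 4]
for i in range(len(group)):
    for j in range(i + 1, len(group)):
        mask[group[i], group[j]] = True
print(mask.int())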
Example #2
File: amr.py Project: lei1993/HanLP
# Method of an AMR transform class; assumes typing.Dict/Any plus the
# HanLP-internal CLS constant, merge_list_of_dict, get_concepts and
# batchify are importable in this module.
def transform_batch(self,
                    batch: Dict[str, Any],
                    results: Dict[str, Any] = None,
                    cls_is_bos=False,
                    sep_is_eos=False) -> Dict[str, Any]:
    batch = super().transform_batch(batch, results, cls_is_bos, sep_is_eos)
    # Prepend [CLS] to each lemma sequence; `results` must carry 'lem'
    # here, despite the None default.
    batch['lemma'] = [[CLS] + x for x in results['lem']]
    # Build predictable-concept (copy) features per sentence, slicing off
    # the [CLS] slot first.
    copy_seq = merge_list_of_dict([
        get_concepts({
            'token': t[1:],
            'lemma': l[1:]
        }, self.vocabs.predictable_concept)
        for t, l in zip(batch['token'], batch['lemma'])
    ])
    copy_seq.pop('token')
    copy_seq.pop('lemma')
    batch.update(copy_seq)
    ret = batchify(batch,
                   self.vocabs,
                   device=batch['token_input_ids'].device)
    return ret
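The [CLS] bookkeeping above follows a common pattern: prepend the special token for the encoder, then slice it off (t[1:], l[1:]) when building raw features. A toy illustration; the CLS value here is a stand-in for whatever constant the module defines.

CLS = '[CLS]'  # stand-in; HanLP defines its own constant
lem = [['run', 'dog'], ['eat']]
lemma = [[CLS] + x for x in lem]
print(lemma)                   # [['[CLS]', 'run', 'dog'], ['[CLS]', 'eat']]
print([l[1:] for l in lemma])  # back to the raw lemmas for feature building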
Example #3
File: hotpotqa.py Project: lei1993/HanLP
# Requires torch; Vertex and merge_list_of_dict are HanLP-internal helpers
# assumed importable in this module.
import torch
from torch.nn.utils.rnn import pad_sequence


def hotpotqa_collate_fn(samples):
    batch = merge_list_of_dict(samples)
    # Allocate padded tensors sized to the largest graph in the batch.
    max_seq_len = len(max([x['graph'] for x in samples], key=len))
    arc = torch.zeros([len(samples), max_seq_len, max_seq_len])
    token_offset = torch.zeros([len(samples), max_seq_len], dtype=torch.long)
    src_mask = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
    sp_candidate_mask = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
    sp_label = torch.zeros([len(samples), max_seq_len], dtype=torch.float)
    # sp = torch.zeros([len(samples), max_seq_len], dtype=torch.bool)
    tokens = []
    offset = 0
    for i, sample in enumerate(samples):
        graph = sample['graph']
        for j, u in enumerate(graph):
            u: Vertex = u
            # Symmetric adjacency over vertex ids.
            for v in u.to:
                v: Vertex = v
                arc[i, v.id, u.id] = 1
                arc[i, u.id, v.id] = 1
            # Record each vertex's row in the flattened, batch-wide `tokens`
            # list (note: `offset` is not reset between samples).
            token_offset[i, u.id] = offset
            src_mask[i, u.id] = True
            sp_candidate_mask[i, u.id] = u.is_sp_root_candidate()
            sp_label[i, u.id] = u.is_sp_root()
            offset += 1
        tokens.extend(sample['token_id'])
    seq_lengths = torch.LongTensor(list(map(len, tokens)))
    tokens = [torch.LongTensor(x) for x in tokens]
    tokens = pad_sequence(tokens, batch_first=True)
    batch['adj'] = arc
    batch['tokens'] = tokens
    batch['src_mask'] = src_mask
    batch['seq_lengths'] = seq_lengths
    batch['token_offset'] = token_offset
    batch['sp_candidate_mask'] = sp_candidate_mask
    batch['sp_label'] = sp_label
    return batch
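The final padding step relies on torch's pad_sequence, which right-pads variable-length rows with zeros when batch_first=True; a self-contained reference:

import torch
from torch.nn.utils.rnn import pad_sequence

rows = [torch.LongTensor([7, 8, 9]), torch.LongTensor([4, 5])]
print(pad_sequence(rows, batch_first=True))
# tensor([[7, 8, 9],
#         [4, 5, 0]])
print(list(map(len, rows)))  # [3, 2] -- the seq_lengths computed above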
Example #4
def collate_fn(self, samples):
    # Default collation: merge per-sample dicts into a single batched dict.
    return merge_list_of_dict(samples)
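merge_list_of_dict itself is a HanLP-internal helper; judging from how every example here uses it, it transposes a list of per-sample dicts into one dict of lists. A minimal sketch of that assumed behavior:

def merge_list_of_dict(samples):
    # Assumed behavior: {key: [sample[key] for sample in samples]}.
    merged = {}
    for sample in samples:
        for key, value in sample.items():
            merged.setdefault(key, []).append(value)
    return merged


print(merge_list_of_dict([{'token': [1, 2]}, {'token': [3]}]))
# {'token': [[1, 2], [3]]}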
Example #5
File: amr.py Project: lei1993/HanLP
# Requires torch plus HanLP-internal helpers assumed importable here: END,
# SelfAttentionMask, PadSequenceDataLoader, merge_list_of_dict.
import torch


def make_batch_for_squeeze(data, augmented_concept, tokenizer, device, ret):
    token_field = 'token_and_concept'
    attention_mask = []
    # Encode each sentence and its concepts as one sequence:
    # tokens [SEP] concepts.
    token_and_concept = [
        t + [tokenizer.sep_token] + c
        for t, c in zip(data['token'], augmented_concept)
    ]
    encodings = [tokenizer({token_field: x}) for x in token_and_concept]
    ret.update(merge_list_of_dict(encodings))
    max_input_len = len(max(ret[f'{token_field}_input_ids'], key=len))
    concept_mask = []
    token_mask = []
    token_type_ids = []
    snt_len = []
    last_concept_offset = []
    for tokens, concepts, input_ids, spans in zip(
            data['token'], augmented_concept,
            ret['token_and_concept_input_ids'],
            ret['token_and_concept_token_span']):
        raw_sent_len = len(tokens) + 1  # +1 for [SEP]
        raw_concept_len = len(concepts)
        # Boolean masks selecting concept vs. token positions.
        if concepts[-1] == END:
            concept_mask.append([False] * raw_sent_len + [True] *
                                (raw_concept_len - 1) +
                                [False])  # skip END concept
        else:
            concept_mask.append([False] * raw_sent_len +
                                [True] * raw_concept_len)
        token_mask.append([False] + [True] * (raw_sent_len - 2) + [False] *
                          (raw_concept_len + 1))
        assert len(concept_mask) == len(token_mask)
        snt_len.append(raw_sent_len - 2)  # skip [CLS] and [SEP]
        sent_len = input_ids.index(tokenizer.tokenizer.sep_token_id) + 1
        concept_len = len(input_ids) - sent_len
        # Every position may attend to the sentence part...
        mask = torch.zeros((max_input_len, max_input_len), dtype=torch.bool)
        mask[:sent_len + concept_len, :sent_len] = True
        # ...while concepts attend to one another causally.
        bottom_right = ~SelfAttentionMask.get_mask(
            concept_len, device, ret_parameter=False)
        mask[sent_len:sent_len + concept_len,
             sent_len:sent_len + concept_len] = bottom_right
        # Subwords within the same concept span may see one another.
        for group in spans:
            if group[0] >= sent_len:
                for i in range(len(group)):
                    for j in range(i + 1, len(group)):
                        mask[group[i], group[j]] = True
        attention_mask.append(mask)
        # Segment ids: 0 for the sentence part, 1 for the concept part.
        _token_type_ids = [0] * sent_len + [1] * concept_len
        token_type_ids.append(_token_type_ids)
        assert len(input_ids) == len(_token_type_ids)
        last_concept_offset.append(raw_concept_len - 1)
    ret['attention_mask'] = torch.stack(attention_mask)
    ret['concept_mask'] = PadSequenceDataLoader.pad_data(
        concept_mask, 0, torch.bool)
    ret['token_mask'] = PadSequenceDataLoader.pad_data(token_mask, 0,
                                                       torch.bool)
    ret['token_type_ids'] = PadSequenceDataLoader.pad_data(
        token_type_ids, 0, torch.long)
    ret['snt_len'] = PadSequenceDataLoader.pad_data(snt_len, 0, torch.long)
    ret['last_concept_offset'] = PadSequenceDataLoader.pad_data(
        last_concept_offset, 0, torch.long)
    return token_field
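The attention layout built above splits one sequence into a sentence block and a concept block: every position may attend to the sentence, while concepts attend to each other causally (plus the intra-span exception). A toy torch-only rendering with hypothetical sizes sent_len=3, concept_len=2:

import torch

sent_len, concept_len = 3, 2
n = sent_len + concept_len
mask = torch.zeros((n, n), dtype=torch.bool)
# Every position may attend to the sentence block.
mask[:, :sent_len] = True
# Concepts attend causally among themselves (lower triangle).
mask[sent_len:, sent_len:] = torch.tril(
    torch.ones(concept_len, concept_len)).bool()
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)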