Example #1
0
 def tensorize(raw_batch: Dict[str, Any], vocabs: VocabDict, pad_dict: Dict[str, int] = None, device=None):
     """Pad and tensorize the list-valued fields of a raw batch, in place.

     Args:
         raw_batch: Maps field name to either a ``torch.Tensor`` (left as-is)
             or a list of values to be padded into a tensor.
         vocabs: Vocabulary lookup; a field named ``<key>_id`` is padded with
             the pad index of ``vocabs[<key>]`` when that vocab exists.
         pad_dict: Optional explicit ``field -> pad value`` overrides; the
             tensor dtype is inferred from the pad value via ``dtype_of``.
         device: If not ``None``, every tensor in the batch is moved there.

     Returns:
         The same ``raw_batch`` dict, mutated so that recognized fields are
         padded ``torch.Tensor`` objects; unrecognized fields are untouched.
     """
     for field, data in raw_batch.items():
         if isinstance(data, torch.Tensor):
             continue  # already tensorized
         # A field like 'token_id' is looked up under the 'token' vocab.
         vocab_key = field[:-len('_id')] if field.endswith('_id') else None
         vocab: Vocab = vocabs.get(vocab_key, None) if vocabs and vocab_key else None
         if vocab:
             pad = vocab.safe_pad_token_idx
             dtype = torch.long
         elif pad_dict is not None and field in pad_dict:
             pad = pad_dict[field]
             dtype = dtype_of(pad)
         elif field.endswith(('_offset', '_id', '_count', '_ids', '_score',
                              '_length', '_span')):
             # guess some common fields to pad
             pad = 0
             dtype = torch.long
         elif field.endswith('_mask'):
             pad = False
             dtype = torch.bool
         else:
             # no need to pad
             continue
         data = PadSequenceDataLoader.pad_data(data, pad, dtype)
         raw_batch[field] = data
     if device is not None:
         for field, data in raw_batch.items():
             if isinstance(data, torch.Tensor):
                 raw_batch[field] = data.to(device)
     return raw_batch
Example #2
0
def batchify(data,
             vocabs: VocabDict,
             unk_rate=0.,
             device=None,
             squeeze=False,
             tokenizer: TransformerSequenceTokenizer = None,
             shuffle_sibling=True,
             levi_graph=False,
             extra_arc=False,
             bart=False):
    """Convert a raw AMR sample dict into a batch of tensors.

    Args:
        data: Raw fields ('token', 'lemma', 'cp_seq', 'mp_seq', 'token2idx',
            'idx2token'; optionally 'pos', 'ner', 'amr', 'token_length').
        vocabs: Vocabularies mapping symbols to indices; must expose ``rel``.
        unk_rate: Probability of replacing a token with UNK (augmentation).
        device: Target device — a ``torch.device``, a device string, or a
            non-negative GPU ordinal; anything else falls back to CPU.
        squeeze: If True, build a joint text+concept field via
            ``make_batch_for_squeeze``.
        tokenizer: Subword tokenizer used by the squeeze/BART paths.
        shuffle_sibling: Shuffle sibling nodes when linearizing the graph.
        levi_graph: ``False``, ``True`` or ``'kahn'`` — graph linearization.
        extra_arc: Passed to ``levi_amr`` when ``levi_graph is True``.
        bart: If True, additionally build BART-style decoder batches.

    Returns:
        A dict with the original fields plus tensorized inputs and (when
        'amr' is present) concept/relation supervision targets.
    """
    rel_vocab: VocabWithFrequency = vocabs.rel
    _tok = list_to_tensor(data['token'], vocabs['token'],
                          unk_rate=unk_rate) if 'token' in vocabs else None
    _lem = list_to_tensor(data['lemma'], vocabs['lemma'], unk_rate=unk_rate)
    _pos = list_to_tensor(data['pos'], vocabs['pos'],
                          unk_rate=unk_rate) if 'pos' in vocabs else None
    _ner = list_to_tensor(data['ner'], vocabs['ner'],
                          unk_rate=unk_rate) if 'ner' in vocabs else None
    _word_char = lists_of_string_to_tensor(
        data['token'], vocabs['word_char']) if 'word_char' in vocabs else None

    local_token2idx = data['token2idx']
    local_idx2token = data['idx2token']
    # Copy-mechanism sequences mapped through the predictable-concept vocab,
    # extended by the per-sample local token index.
    _cp_seq = list_to_tensor(data['cp_seq'], vocabs['predictable_concept'],
                             local_token2idx)
    _mp_seq = list_to_tensor(data['mp_seq'], vocabs['predictable_concept'],
                             local_token2idx)

    ret = copy(data)
    if 'amr' in data:
        concept, edge = [], []
        for amr in data['amr']:
            if levi_graph == 'kahn':
                concept_i, edge_i = amr.to_levi(rel_vocab.get_frequency,
                                                shuffle=shuffle_sibling)
            else:
                concept_i, edge_i, _ = amr.root_centered_sort(
                    rel_vocab.get_frequency, shuffle=shuffle_sibling)
            concept.append(concept_i)
            edge.append(edge_i)
        if levi_graph is True:
            # NOTE: only the boolean True takes this path; 'kahn' was handled
            # per-graph above.
            concept_with_rel, edge_with_rel = levi_amr(concept,
                                                       edge,
                                                       extra_arc=extra_arc)
            concept = concept_with_rel
            edge = edge_with_rel

        # Bracket each concept sequence with dummy start / end markers.
        augmented_concept = [[DUM] + x + [END] for x in concept]

        # Teacher forcing: inputs drop the final step, targets drop the first.
        _concept_in = list_to_tensor(augmented_concept,
                                     vocabs.get('concept_and_rel',
                                                vocabs['concept']),
                                     unk_rate=unk_rate)[:-1]
        _concept_char_in = lists_of_string_to_tensor(
            augmented_concept, vocabs['concept_char'])[:-1]
        _concept_out = list_to_tensor(augmented_concept,
                                      vocabs['predictable_concept'],
                                      local_token2idx)[1:]

        out_conc_len, bsz = _concept_out.shape
        _rel = np.full((1 + out_conc_len, bsz, out_conc_len),
                       rel_vocab.pad_idx)
        # v: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}, <end>] u: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}]

        for bidx, (x, y) in enumerate(zip(edge, concept)):
            # Default every (v, u) pair with u < v to NIL (no relation) ...
            for l, _ in enumerate(y):
                if l > 0:
                    # l=1 => pos=l+1=2
                    _rel[l + 1, bidx, 1:l + 1] = rel_vocab.get_idx(NIL)
            # ... then overwrite the actual arcs.
            for v, u, r in x:
                if levi_graph:
                    r = 1
                else:
                    r = rel_vocab.get_idx(r)
                assert v > u, 'Invalid typological order'
                _rel[v + 1, bidx, u + 1] = r
        ret.update({
            'concept_in': _concept_in,
            'concept_char_in': _concept_char_in,
            'concept_out': _concept_out,
            'rel': _rel
        })
    else:
        augmented_concept = None

    token_length = ret.get('token_length', None)
    if token_length is not None and not isinstance(token_length, torch.Tensor):
        # Accept torch.device, a device string, or a non-negative GPU
        # ordinal; anything else (None, negative int) falls back to CPU.
        # The previous `device >= 0` comparison raised TypeError when
        # `device` was None (the declared default) or a string.
        if isinstance(device, (torch.device, str)) or (
                isinstance(device, int) and device >= 0):
            tl_device = device
        else:
            tl_device = 'cpu:0'
        ret['token_length'] = torch.tensor(
            token_length, dtype=torch.long, device=tl_device)
    ret.update({
        'lem': _lem,
        'tok': _tok,
        'pos': _pos,
        'ner': _ner,
        'word_char': _word_char,
        'copy_seq': np.stack([_cp_seq, _mp_seq], -1),
        'local_token2idx': local_token2idx,
        'local_idx2token': local_idx2token
    })
    if squeeze:
        token_field = make_batch_for_squeeze(data, augmented_concept,
                                             tokenizer, device, ret)
    else:
        token_field = 'token'
    subtoken_to_tensor(token_field, ret)
    if bart:
        make_batch_for_bart(augmented_concept, ret, tokenizer, device)
    move_dict_to_device(ret, device)

    return ret