def tensorize(raw_batch: Dict[str, Any], vocabs: VocabDict, pad_dict: Dict[str, int] = None, device=None):
    """Convert list-valued fields of a raw batch into padded tensors, in place.

    Fields that are already tensors are left untouched (except for the final
    device move). For each remaining field the pad value and dtype are chosen,
    in priority order:

    1. If the field is named ``<x>_id`` and ``vocabs`` has a vocab for ``<x>``,
       pad with that vocab's ``safe_pad_token_idx`` as ``torch.long``.
    2. If ``pad_dict`` supplies an explicit pad value for the field, use it,
       with the dtype inferred by ``dtype_of``.
    3. If the field name ends with a common numeric suffix
       (``_offset``/``_id``/``_count``/``_ids``/``_score``/``_length``/``_span``),
       guess pad ``0`` as ``torch.long``.
    4. If it ends with ``_mask``, pad with ``False`` as ``torch.bool``.
    5. Otherwise the field is left as-is (no padding needed).

    Args:
        raw_batch: Mapping of field name to per-sample data; mutated in place.
        vocabs: Vocabularies used to resolve pad indices for ``*_id`` fields.
        pad_dict: Optional explicit ``field -> pad value`` overrides.
        device: If not ``None``, every tensor in the batch is moved there.

    Returns:
        The same ``raw_batch`` dict, with padded (and possibly moved) tensors.
    """
    # One endswith call with a tuple replaces the original 7-way `or` chain.
    guessed_long_suffixes = ('_offset', '_id', '_count', '_ids', '_score', '_length', '_span')
    for field, data in raw_batch.items():
        if isinstance(data, torch.Tensor):
            continue
        vocab_key = field[:-len('_id')] if field.endswith('_id') else None
        vocab: Vocab = vocabs.get(vocab_key, None) if vocabs and vocab_key else None
        if vocab:
            pad = vocab.safe_pad_token_idx
            dtype = torch.long
        elif pad_dict is not None and field in pad_dict:
            pad = pad_dict[field]
            dtype = dtype_of(pad)
        elif field.endswith(guessed_long_suffixes):
            # guess some common fields to pad
            pad = 0
            dtype = torch.long
        elif field.endswith('_mask'):
            pad = False
            dtype = torch.bool
        else:
            # no need to pad
            continue
        data = PadSequenceDataLoader.pad_data(data, pad, dtype)
        raw_batch[field] = data
    if device is not None:
        # Second pass so that fields which entered the batch already as tensors
        # are moved too, not only the ones padded above.
        for field, data in raw_batch.items():
            if isinstance(data, torch.Tensor):
                raw_batch[field] = data.to(device)
    return raw_batch
def batchify(data, vocabs: VocabDict, unk_rate=0., device=None, squeeze=False,
             tokenizer: TransformerSequenceTokenizer = None, shuffle_sibling=True,
             levi_graph=False, extra_arc=False, bart=False):
    """Turn a raw AMR sample dict into a model-ready batch of tensors.

    Encodes token/lemma/pos/ner (and char-level) inputs, builds the copy
    sequences, and — when gold AMR graphs are present — linearizes each graph
    (Levi transform or root-centered sort), producing teacher-forcing inputs
    (``concept_in``/``concept_char_in``), prediction targets (``concept_out``)
    and the dense relation tensor ``rel``.

    Args:
        data: Raw batch dict; must contain ``token``, ``lemma``, ``token2idx``,
            ``idx2token``, ``cp_seq`` and ``mp_seq``; ``amr`` is optional.
        vocabs: Vocabulary collection; ``vocabs.rel`` supplies relation
            frequencies and indices.
        unk_rate: Probability of replacing known tokens with UNK (word dropout).
        device: Target device; ``None`` or a negative int means CPU.
        squeeze: Concatenate text and concepts into one token field via
            ``make_batch_for_squeeze``.
        tokenizer: Subword tokenizer used by the squeeze/BART paths.
        shuffle_sibling: Shuffle sibling order during graph linearization.
        levi_graph: ``True`` for Levi transform, ``'kahn'`` for Kahn-style
            Levi linearization, falsy for root-centered sort.
        extra_arc: Passed through to ``levi_amr``.
        bart: Also build BART-style inputs via ``make_batch_for_bart``.

    Returns:
        A shallow copy of ``data`` augmented with tensorized fields.
    """
    rel_vocab: VocabWithFrequency = vocabs.rel
    _tok = list_to_tensor(data['token'], vocabs['token'], unk_rate=unk_rate) if 'token' in vocabs else None
    # NOTE(review): 'lemma' is encoded unconditionally while token/pos/ner are
    # guarded by membership in vocabs — presumably lemma is always present.
    _lem = list_to_tensor(data['lemma'], vocabs['lemma'], unk_rate=unk_rate)
    _pos = list_to_tensor(data['pos'], vocabs['pos'], unk_rate=unk_rate) if 'pos' in vocabs else None
    _ner = list_to_tensor(data['ner'], vocabs['ner'], unk_rate=unk_rate) if 'ner' in vocabs else None
    _word_char = lists_of_string_to_tensor(
        data['token'], vocabs['word_char']) if 'word_char' in vocabs else None
    local_token2idx = data['token2idx']
    local_idx2token = data['idx2token']
    _cp_seq = list_to_tensor(data['cp_seq'], vocabs['predictable_concept'], local_token2idx)
    _mp_seq = list_to_tensor(data['mp_seq'], vocabs['predictable_concept'], local_token2idx)
    ret = copy(data)
    if 'amr' in data:
        # Linearize each AMR graph into a concept sequence plus (v, u, r) edges.
        concept, edge = [], []
        for amr in data['amr']:
            if levi_graph == 'kahn':
                concept_i, edge_i = amr.to_levi(rel_vocab.get_frequency, shuffle=shuffle_sibling)
            else:
                concept_i, edge_i, _ = amr.root_centered_sort(
                    rel_vocab.get_frequency, shuffle=shuffle_sibling)
            concept.append(concept_i)
            edge.append(edge_i)
        if levi_graph is True:
            concept_with_rel, edge_with_rel = levi_amr(concept, edge, extra_arc=extra_arc)
            concept = concept_with_rel
            edge = edge_with_rel
        # Wrap each sequence with dummy-start / end sentinels; shift by one for
        # teacher forcing: inputs drop the last step, targets drop the first.
        augmented_concept = [[DUM] + x + [END] for x in concept]
        _concept_in = list_to_tensor(augmented_concept, vocabs.get('concept_and_rel', vocabs['concept']),
                                     unk_rate=unk_rate)[:-1]
        _concept_char_in = lists_of_string_to_tensor(
            augmented_concept, vocabs['concept_char'])[:-1]
        _concept_out = list_to_tensor(augmented_concept, vocabs['predictable_concept'],
                                      local_token2idx)[1:]
        out_conc_len, bsz = _concept_out.shape
        # rel[v, b, u]: relation label from concept v to earlier concept u.
        _rel = np.full((1 + out_conc_len, bsz, out_conc_len), rel_vocab.pad_idx)
        # v: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}, <end>]
        # u: [<dummy>, concept_0, ..., concept_l, ..., concept_{n-1}]
        for bidx, (x, y) in enumerate(zip(edge, concept)):
            for l, _ in enumerate(y):
                if l > 0:
                    # Default every (later, earlier) pair to NIL before writing
                    # real edges.  l=1 => pos=l+1=2
                    _rel[l + 1, bidx, 1:l + 1] = rel_vocab.get_idx(NIL)
            for v, u, r in x:
                if levi_graph:
                    r = 1  # Levi graphs only need arc existence, not labels.
                else:
                    r = rel_vocab.get_idx(r)
                assert v > u, 'Invalid typological order'
                _rel[v + 1, bidx, u + 1] = r
        ret.update({
            'concept_in': _concept_in,
            'concept_char_in': _concept_char_in,
            'concept_out': _concept_out,
            'rel': _rel
        })
    else:
        augmented_concept = None
    token_length = ret.get('token_length', None)
    if token_length is not None and not isinstance(token_length, torch.Tensor):
        # FIX: the original expression `device >= 0` raised TypeError whenever
        # device was None (the default) or a string such as 'cuda:0'.  Only
        # torch.device instances and non-negative ints are usable directly;
        # everything else falls back to CPU, as negative ints already did.
        if isinstance(device, torch.device) or (isinstance(device, int) and device >= 0):
            _length_device = device
        else:
            _length_device = 'cpu:0'
        ret['token_length'] = torch.tensor(token_length, dtype=torch.long, device=_length_device)
    ret.update({
        'lem': _lem,
        'tok': _tok,
        'pos': _pos,
        'ner': _ner,
        'word_char': _word_char,
        'copy_seq': np.stack([_cp_seq, _mp_seq], -1),
        'local_token2idx': local_token2idx,
        'local_idx2token': local_idx2token
    })
    if squeeze:
        token_field = make_batch_for_squeeze(data, augmented_concept, tokenizer, device, ret)
    else:
        token_field = 'token'
    subtoken_to_tensor(token_field, ret)
    if bart:
        make_batch_for_bart(augmented_concept, ret, tokenizer, device)
    move_dict_to_device(ret, device)
    return ret