def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
    """
    Parameters
    ----------
    tokenizer
        The tokenizer used to encode the question and the context.
    doc_stride
        The stride between consecutive chunks when splitting a long document.
    max_seq_length
        Maximum length of the merged (query + context) sequence.
    max_query_length
        Maximum query length.
    """
    self._tokenizer = tokenizer
    self._doc_stride = doc_stride
    self._max_seq_length = max_seq_length
    self._max_query_length = max_query_length

    vocab = tokenizer.vocab
    self.pad_id = vocab.pad_id
    # For RoBERTa-style models, treat the special token <s> as [CLS]
    # and </s> as [SEP].
    self.cls_id = vocab.bos_id if 'cls_token' not in vocab.special_token_keys \
        else vocab.cls_id
    self.sep_id = vocab.eos_id if 'sep_token' not in vocab.special_token_keys \
        else vocab.sep_id

    # TODO(sxjscience) Consider combining the NamedTuple and batchify
    #  functionality.
    self.ChunkFeature = collections.namedtuple(
        'ChunkFeature',
        ['qas_id', 'data', 'valid_length', 'segment_ids', 'masks',
         'is_impossible', 'gt_start', 'gt_end', 'context_offset',
         'chunk_start', 'chunk_length'])
    self.BatchifyFunction = bf.NamedTuple(
        self.ChunkFeature,
        {'qas_id': bf.List(),
         'data': bf.Pad(val=self.pad_id),
         'valid_length': bf.Stack(),
         'segment_ids': bf.Pad(),
         'masks': bf.Pad(val=1),
         'is_impossible': bf.Stack(),
         'gt_start': bf.Stack(),
         'gt_end': bf.Stack(),
         'context_offset': bf.Stack(),
         'chunk_start': bf.Stack(),
         'chunk_length': bf.Stack()})
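
# --- Usage sketch (illustrative, not part of the original class) -----------
# A minimal example of what the ChunkFeature / BatchifyFunction pair defined
# above does at collation time: bf.Pad right-pads the variable-length fields
# to the longest sample in the batch, bf.Stack stacks the per-chunk scalars,
# and bf.List keeps qas_id as a plain Python list. The structures are rebuilt
# standalone here so the sketch runs without a tokenizer; the import path
# assumes GluonNLP's `gluonnlp.data.batchify` module, and all feature values
# are made up.
import collections
from gluonnlp.data import batchify as bf

ChunkFeature = collections.namedtuple(
    'ChunkFeature',
    ['qas_id', 'data', 'valid_length', 'segment_ids', 'masks',
     'is_impossible', 'gt_start', 'gt_end', 'context_offset',
     'chunk_start', 'chunk_length'])
batchify_fn = bf.NamedTuple(
    ChunkFeature,
    {'qas_id': bf.List(), 'data': bf.Pad(val=0), 'valid_length': bf.Stack(),
     'segment_ids': bf.Pad(), 'masks': bf.Pad(val=1),
     'is_impossible': bf.Stack(), 'gt_start': bf.Stack(),
     'gt_end': bf.Stack(), 'context_offset': bf.Stack(),
     'chunk_start': bf.Stack(), 'chunk_length': bf.Stack()})

chunk_a = ChunkFeature(qas_id='q0', data=[2, 8, 9, 5, 3], valid_length=5,
                       segment_ids=[0, 0, 1, 1, 1], masks=[1, 0, 0, 0, 1],
                       is_impossible=0, gt_start=2, gt_end=3,
                       context_offset=2, chunk_start=0, chunk_length=3)
chunk_b = ChunkFeature(qas_id='q1', data=[2, 8, 3], valid_length=3,
                       segment_ids=[0, 0, 1], masks=[1, 0, 1],
                       is_impossible=1, gt_start=0, gt_end=0,
                       context_offset=2, chunk_start=0, chunk_length=0)

# `data`, `segment_ids` and `masks` come back padded to length 5; the scalar
# fields come back stacked; `qas_id` stays ['q0', 'q1'].
batch = batchify_fn([chunk_a, chunk_b])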
from gluonnlp.data import batchify  # import path assumed from the surrounding project


def test_list():
    # batchify.List is a passthrough: it keeps the batch as a plain Python
    # list without stacking or padding.
    data = [object() for _ in range(5)]
    passthrough = batchify.List()(data)
    assert passthrough == data
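
# A companion sketch in the same test style (hypothetical, not from the
# original suite): batchify.Pad right-pads ragged samples to the longest
# length in the batch using the given fill value, and batchify.Stack stacks
# equal-shape samples along a new batch axis. Assumes the returned arrays
# expose `.asnumpy()`, as MXNet-backed batchify outputs do.
def test_pad_and_stack_sketch():
    ragged = [[1, 2, 3], [4, 5]]
    padded = batchify.Pad(val=0)(ragged)
    assert padded.asnumpy().tolist() == [[1, 2, 3], [4, 5, 0]]

    scalars = [1, 2, 3]
    stacked = batchify.Stack()(scalars)
    assert stacked.shape == (3,)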