def test_named_tuple():
    """NamedTuple built from (field name, function) pairs returns a type called ``name``.

    Each field of the batch must equal the result of applying that field's
    batchify function to the corresponding column of the samples.

    NOTE(review): another ``def test_named_tuple`` appears later in this file
    and rebinds this name, so pytest never collects this version. Rename it
    (e.g. ``test_named_tuple_from_pairs``) if it is meant to run, and confirm
    that ``batchify.NamedTuple(pairs, name=...)`` still matches the current
    API. The later test constructs NamedTuple from a container class instead.
    """
    first = ([1, 2, 3, 4], 0)
    second = ([5, 7], 1)
    third = ([1, 2, 3, 4, 5, 6, 7], 0)
    samples = (first, second, third)

    field_fns = [('data', batchify.Pad()), ('label', batchify.Stack())]
    batchify_fn = batchify.NamedTuple(field_fns, name='SomeName')
    batch = batchify_fn(list(samples))

    # Reference results: apply each field's function to its column directly.
    expected_data = batchify.Pad()([s[0] for s in samples])
    expected_label = batchify.Stack()([s[1] for s in samples])

    assert_allclose(batch.data.asnumpy(), expected_data.asnumpy())
    assert_allclose(batch.label.asnumpy(), expected_label.asnumpy())
    assert type(batch).__name__ == 'SomeName'
def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
    """Store chunking settings and build the chunk-feature type and its batchify function.

    Parameters
    ----------
    tokenizer
        The tokenizer. Its ``vocab`` supplies the padding, [CLS] and [SEP] token ids.
    doc_stride
        The stride used to chunk the document
    max_seq_length
        Maximum length of the merged data
    max_query_length
        Maximum query length
    """
    # Configuration, kept as private attributes.
    self._tokenizer = tokenizer
    self._doc_stride = doc_stride
    self._max_seq_length = max_seq_length
    self._max_query_length = max_query_length

    token_vocab = tokenizer.vocab
    self.pad_id = token_vocab.pad_id
    # Some vocabularies (e.g. RoBERTa's) have no dedicated [CLS]/[SEP] entries.
    # For those, the BOS token <s> serves as [CLS] and the EOS token </s> as [SEP].
    if 'cls_token' in token_vocab.special_token_keys:
        self.cls_id = token_vocab.cls_id
    else:
        self.cls_id = token_vocab.bos_id
    if 'sep_token' in token_vocab.special_token_keys:
        self.sep_id = token_vocab.sep_id
    else:
        self.sep_id = token_vocab.eos_id

    # TODO(sxjscience) Consider to combine the NamedTuple and batchify functionality.
    feature_fields = ['qas_id', 'data', 'valid_length', 'segment_ids', 'masks',
                      'is_impossible', 'gt_start', 'gt_end', 'context_offset',
                      'chunk_start', 'chunk_length']
    self.ChunkFeature = collections.namedtuple('ChunkFeature', feature_fields)
    # Per-field batchify functions, in field order. `qas_id` is gathered into a
    # list. `data` is padded with the vocab's pad id, `segment_ids` with
    # bf.Pad()'s default, and `masks` with 1. Every remaining field is stacked.
    field_batchify = dict(
        qas_id=bf.List(),
        data=bf.Pad(val=self.pad_id),
        valid_length=bf.Stack(),
        segment_ids=bf.Pad(),
        masks=bf.Pad(val=1),
        is_impossible=bf.Stack(),
        gt_start=bf.Stack(),
        gt_end=bf.Stack(),
        context_offset=bf.Stack(),
        chunk_start=bf.Stack(),
        chunk_length=bf.Stack(),
    )
    self.BatchifyFunction = bf.NamedTuple(self.ChunkFeature, field_batchify)
def test_dict():
    """Dict batchify applies one function per key to a list of dict samples.

    Also checks that malformed specifications raise ``ValueError`` at
    construction time.
    """
    first = {'data': [1, 2, 3, 4], 'label': 0}
    second = {'data': [5, 7], 'label': 1}
    third = {'data': [1, 2, 3, 4, 5, 6, 7], 'label': 0}
    samples = (first, second, third)

    # A bare list of functions (instead of a key -> function mapping) is rejected.
    with pytest.raises(ValueError):
        batchify.Dict([batchify.Pad(pad_val=0), batchify.Stack()])
    # NOTE(review): this case exercises NamedTuple, not Dict. It may have been
    # meant as a Dict check with non-callable values; confirm the intent.
    with pytest.raises(ValueError):
        batchify.NamedTuple(MyNamedTuple, {'a': 1, 'b': 2})

    batchify_fn = batchify.Dict({
        'data': batchify.Pad(pad_val=0),
        'label': batchify.Stack(),
    })
    batch = batchify_fn(list(samples))

    # Reference results: apply each key's function to that key's column.
    expected_data = batchify.Pad(pad_val=0)([s['data'] for s in samples])
    expected_label = batchify.Stack()([s['label'] for s in samples])

    assert isinstance(batch, dict)
    assert_allclose(batch['data'].asnumpy(), expected_data.asnumpy())
    assert_allclose(batch['label'].asnumpy(), expected_label.asnumpy())
def test_named_tuple():
    """NamedTuple batchify accepts a dict, list or tuple of per-field functions.

    Malformed specifications must raise ``ValueError`` at construction time.
    Each valid batchify function must produce a ``MyNamedTuple`` whose fields
    match the per-field reference results, and must reject a plain list of
    ints with ``ValueError``.
    """
    first = MyNamedTuple([1, 2, 3, 4], 0)
    second = MyNamedTuple([5, 7], 1)
    third = MyNamedTuple([1, 2, 3, 4, 5, 6, 7], 0)
    samples = (first, second, third)

    # Each construction is deferred in a lambda so that it runs inside pytest.raises.
    invalid_builders = [
        # Mapping key 'data0' does not match a field name.
        lambda: batchify.NamedTuple(
            MyNamedTuple,
            {'data0': batchify.Pad(pad_val=0), 'label': batchify.Stack()}),
        # Too many functions.
        lambda: batchify.NamedTuple(
            MyNamedTuple,
            [batchify.Pad(pad_val=0), batchify.Stack(), batchify.Stack()]),
        # Too few functions.
        lambda: batchify.NamedTuple(MyNamedTuple, (batchify.Pad(pad_val=0),)),
        # Entries are not callables.
        lambda: batchify.NamedTuple(MyNamedTuple, [1, 2]),
    ]
    for build in invalid_builders:
        with pytest.raises(ValueError):
            build()

    # The same field functions, given as a dict, a list and a tuple.
    candidate_fns = [
        batchify.NamedTuple(MyNamedTuple, {'data': batchify.Pad(pad_val=0),
                                           'label': batchify.Stack()}),
        batchify.NamedTuple(MyNamedTuple, [batchify.Pad(pad_val=0), batchify.Stack()]),
        batchify.NamedTuple(MyNamedTuple, (batchify.Pad(pad_val=0), batchify.Stack())),
    ]
    for batchify_fn in candidate_fns:
        batch = batchify_fn(list(samples))
        expected_data = batchify.Pad(pad_val=0)([s[0] for s in samples])
        expected_label = batchify.Stack()([s[1] for s in samples])
        assert isinstance(batch, MyNamedTuple)
        assert_allclose(batch.data.asnumpy(), expected_data.asnumpy())
        assert_allclose(batch.label.asnumpy(), expected_label.asnumpy())
        # A plain list of ints (not MyNamedTuple samples) is rejected.
        with pytest.raises(ValueError):
            batchify_fn([1, 2, 3])