Beispiel #1
0
def test_named_tuple():
    """Batch (data, label) pairs with a NamedTuple built from (field, fn) pairs.

    Each field of the batched result must equal what that field's own
    batchify function produces on the corresponding column, and the
    generated result type must carry the requested ``name``.
    """
    records = (([1, 2, 3, 4], 0), ([5, 7], 1), ([1, 2, 3, 4, 5, 6, 7], 0))
    field_fns = [('data', batchify.Pad()), ('label', batchify.Stack())]
    batchify_fn = batchify.NamedTuple(field_fns, name='SomeName')
    batch = batchify_fn(list(records))

    # Reference results: apply each field's function to its column directly.
    expected_data = batchify.Pad()([rec[0] for rec in records])
    expected_label = batchify.Stack()([rec[1] for rec in records])

    assert_allclose(batch.data.asnumpy(), expected_data.asnumpy())
    assert_allclose(batch.label.asnumpy(), expected_label.asnumpy())
    assert type(batch).__name__ == 'SomeName'
Beispiel #2
0
    def __init__(self, tokenizer, doc_stride, max_seq_length, max_query_length):
        """Store the chunking settings and build the feature record and its batchify function.

        Parameters
        ----------
        tokenizer
            The tokenizer. Its ``vocab`` supplies the padding, [CLS] and [SEP] ids.
        doc_stride
            The stride used to chunk the document
        max_seq_length
            Maximum length of the merged data
        max_query_length
            Maximum query length
        """
        self._tokenizer = tokenizer
        self._doc_stride = doc_stride
        self._max_seq_length = max_seq_length
        self._max_query_length = max_query_length

        vocab = tokenizer.vocab
        self.pad_id = vocab.pad_id
        # Vocabularies without dedicated cls/sep tokens (e.g. RoBERTa) fall back
        # to the special tokens <s> (bos) as [CLS] and </s> (eos) as [SEP].
        if 'cls_token' in vocab.special_token_keys:
            self.cls_id = vocab.cls_id
        else:
            self.cls_id = vocab.bos_id
        if 'sep_token' in vocab.special_token_keys:
            self.sep_id = vocab.sep_id
        else:
            self.sep_id = vocab.eos_id

        # TODO(sxjscience) Consider combining the NamedTuple and batchify functionality.
        feature_fields = ('qas_id', 'data', 'valid_length', 'segment_ids', 'masks',
                          'is_impossible', 'gt_start', 'gt_end', 'context_offset',
                          'chunk_start', 'chunk_length')
        self.ChunkFeature = collections.namedtuple('ChunkFeature', feature_fields)

        # Per-field batchify functions, in the same order as ``feature_fields``:
        # qas_id is kept as a list, data is padded with the vocab's pad id,
        # and masks are padded with 1.
        field_batchify_fns = {
            'qas_id': bf.List(),
            'data': bf.Pad(val=self.pad_id),
            'valid_length': bf.Stack(),
            'segment_ids': bf.Pad(),
            'masks': bf.Pad(val=1),
        }
        # All remaining fields are combined with bf.Stack().
        for field_name in ('is_impossible', 'gt_start', 'gt_end',
                           'context_offset', 'chunk_start', 'chunk_length'):
            field_batchify_fns[field_name] = bf.Stack()
        self.BatchifyFunction = bf.NamedTuple(self.ChunkFeature, field_batchify_fns)
Beispiel #3
0
def test_dict():
    """Check ``batchify.Dict`` input validation and per-key batching of dict samples."""
    a = {'data': [1, 2, 3, 4], 'label': 0}
    b = {'data': [5, 7], 'label': 1}
    c = {'data': [1, 2, 3, 4, 5, 6, 7], 'label': 0}
    # Dict requires a mapping; a plain list of functions must be rejected.
    with pytest.raises(ValueError):
        batchify.Dict([batchify.Pad(pad_val=0), batchify.Stack()])
    # Dict must also reject a mapping whose values are not callable.
    with pytest.raises(ValueError):
        batchify.Dict({'data': 1, 'label': 2})
    # NamedTuple rejects a mapping whose keys do not match the record fields
    # and whose values are not callable (kept from the original coverage).
    with pytest.raises(ValueError):
        batchify.NamedTuple(MyNamedTuple, {'a': 1, 'b': 2})
    batchify_fn = batchify.Dict({
        'data': batchify.Pad(pad_val=0),
        'label': batchify.Stack()
    })
    sample = batchify_fn([a, b, c])
    # Reference results: apply each key's function to that key's column directly.
    gt_data = batchify.Pad(pad_val=0)([a['data'], b['data'], c['data']])
    gt_label = batchify.Stack()([a['label'], b['label'], c['label']])
    assert isinstance(sample, dict)
    assert_allclose(sample['data'].asnumpy(), gt_data.asnumpy())
    assert_allclose(sample['label'].asnumpy(), gt_label.asnumpy())
Beispiel #4
0
def test_named_tuple():
    """Exercise ``batchify.NamedTuple`` with the ``MyNamedTuple`` record type.

    First confirms that invalid specifications raise ``ValueError``:
    mismatched field names, too many or too few functions, and
    non-callable entries. Then checks that dict-, list- and tuple-style
    specifications all batch the records the same way.
    """
    records = (
        MyNamedTuple([1, 2, 3, 4], 0),
        MyNamedTuple([5, 7], 1),
        MyNamedTuple([1, 2, 3, 4, 5, 6, 7], 0),
    )

    invalid_specs = [
        {'data0': batchify.Pad(pad_val=0), 'label': batchify.Stack()},
        [batchify.Pad(pad_val=0), batchify.Stack(), batchify.Stack()],
        (batchify.Pad(pad_val=0), ),
        [1, 2],
    ]
    for spec in invalid_specs:
        with pytest.raises(ValueError):
            batchify.NamedTuple(MyNamedTuple, spec)

    equivalent_fns = [
        batchify.NamedTuple(MyNamedTuple, {'data': batchify.Pad(pad_val=0),
                                           'label': batchify.Stack()}),
        batchify.NamedTuple(MyNamedTuple, [batchify.Pad(pad_val=0), batchify.Stack()]),
        batchify.NamedTuple(MyNamedTuple, (batchify.Pad(pad_val=0), batchify.Stack())),
    ]
    for fn in equivalent_fns:
        batch = fn(list(records))
        # Reference results: apply each field's function to its column directly.
        expected_data = batchify.Pad(pad_val=0)([rec[0] for rec in records])
        expected_label = batchify.Stack()([rec[1] for rec in records])
        assert isinstance(batch, MyNamedTuple)
        assert_allclose(batch.data.asnumpy(), expected_data.asnumpy())
        assert_allclose(batch.label.asnumpy(), expected_label.asnumpy())
        # A batch of plain integers (not records) must raise ValueError.
        with pytest.raises(ValueError):
            fn([1, 2, 3])