示例#1
0
    def encode_long_text(self, long_text, batch=32):
        """Encode a long text by encoding overlapping splits in batches.

        The text is cut into overlapping pieces with ``split_with_overlap``
        (sized by ``self.albert_config['max_length']`` and
        ``self.albert_config['overlap_window']``). The pieces are encoded
        ``batch`` at a time by calling ``self(...)``, the per-batch results
        are joined along axis 0, and the joined array is handed to
        ``self.aggregate_split_text``.

        Args:
            long_text: The text to encode; must be a ``str``.
            batch: Maximum number of splits passed to ``self(...)`` per call.
                Must be at least 1.

        Returns:
            Whatever ``self.aggregate_split_text`` returns for the joined
            per-split encodings (``None`` is passed when there are no splits).

        Raises:
            AssertionError: If ``long_text`` is not a ``str``.
            ValueError: If ``batch`` is smaller than 1.
        """
        assert isinstance(long_text, str)
        if batch < 1:
            # A non-positive batch size would never advance through the
            # splits, so reject it up front instead of looping forever.
            raise ValueError(
                "batch must be a positive integer, got {!r}".format(batch))

        split_text = split_with_overlap(
            long_text,
            max_length=self.albert_config['max_length'],
            overlap_window_length=self.albert_config['overlap_window'],
            tokenize_func=self.tokenizer.tokenize
        )  # NOTE: This is not fully correct. Has issues with sub-words (results do not differ much, however).

        # Encode batch by batch to bound the size of each model call.
        encoded_batches = [
            self(split_text[start:start + batch]).numpy()
            for start in range(0, len(split_text), batch)
        ]

        # Join once at the end; concatenating inside the loop would re-copy
        # the accumulated array on every iteration.
        if encoded_batches:
            encoded_splits = np.concatenate(encoded_batches, axis=0)
        else:
            encoded_splits = None

        return self.aggregate_split_text(encoded_splits)
示例#2
0
 def test_split_2(self):
     # Split the fixture text with max_length=5 and an overlap of 2 and
     # compare against the exact list of expected pieces.
     chunks = split_with_overlap(
         self.text,
         max_length=5,
         overlap_window_length=2,
     )
     self.assertEqual(
         chunks,
         [
             'One two three four five.',
             'four five. Six seven eight',
             'seven eight nine. Ten.',
         ],
     )
示例#3
0
    def test_max_length(self):
        # Every returned piece must hold at most `limit`
        # whitespace-separated words.
        limit = 2
        pieces = split_with_overlap(
            self.text,
            max_length=limit,
            overlap_window_length=1,
        )

        for piece in pieces:
            word_count = len(piece.split())
            self.assertLessEqual(word_count, limit)
示例#4
0
 def test_output_type(self):
     # The splitter is expected to hand back a list. assertIsInstance is
     # the idiomatic type check and reports the offending object on failure.
     output = split_with_overlap(self.text, 2, 1)
     self.assertIsInstance(output, list)
示例#5
0
 def test_short(self):
     # With arguments 100 and 1 the fixture text is expected to come back
     # as a single, unchanged piece.
     whole_text_only = [self.text]
     result = split_with_overlap(self.text, 100, 1)
     self.assertEqual(result, whole_text_only)