コード例 #1
0
ファイル: test_sequence.py プロジェクト: seeker1943/finetune
 def test_fit_lm_only(self):
     """
     Fit the model once on texts alone (no annotations) and then again with
     annotations, then check prediction / probability output formats, the
     per-label metric keys, and that a saved-and-reloaded model can predict.
     """
     joined_docs = ["".join(chunks) for chunks in self.texts]
     texts, annotations = finetune_to_indico_sequence(
         joined_docs, self.texts, self.labels,
         none_value=self.model.config.pad_token,
     )
     split = train_test_split(texts, annotations, test_size=0.1)
     train_texts, test_texts, train_annotations, test_annotations = split

     # Unannotated fit first, followed by a supervised fit on the same texts.
     self.model.fit(train_texts)
     self.model.fit(train_texts, train_annotations)

     predictions = self.model.predict(test_texts)
     probas = self.model.predict_proba(test_texts)
     self.assertIsInstance(probas, list)
     first_doc = probas[0]
     self.assertIsInstance(first_doc, list)
     first_span = first_doc[0]
     self.assertIsInstance(first_span, dict)
     self.assertIsInstance(first_span['confidence'], dict)

     # Compute every metric first, then check each reports the entity label.
     metric_results = [
         metric(test_annotations, predictions)
         for metric in (
             sequence_labeling_token_precision,
             sequence_labeling_token_recall,
             sequence_labeling_overlap_precision,
             sequence_labeling_overlap_recall,
         )
     ]
     for result in metric_results:
         self.assertIn('Named Entity', result)

     self.model.save(self.save_file)
     reloaded = SequenceLabeler.load(self.save_file)
     reloaded.predict(test_texts)
コード例 #2
0
 def test_fit_predict_multi_model(self):
     """
     Train a multi-label SequenceLabeler with the LM loss disabled and check
     that predict, predict_proba (output format), and a save/load round trip
     all succeed.
     """
     self.model = SequenceLabeler(
         batch_size=2,
         max_length=256,
         lm_loss_coef=0.0,
         multi_label_sequences=True,
     )
     joined_docs = ["".join(chunks) for chunks in self.texts]
     texts, annotations = finetune_to_indico_sequence(
         joined_docs, self.texts, self.labels,
         none_value=self.model.config.pad_token,
     )
     split = train_test_split(texts, annotations, test_size=0.1)
     train_texts, test_texts, train_annotations, _unused_test_annotations = split

     self.model.fit(train_texts, train_annotations)
     self.model.predict(test_texts)

     probas = self.model.predict_proba(test_texts)
     self.assertIsInstance(probas, list)
     first_doc = probas[0]
     self.assertIsInstance(first_doc, list)
     first_span = first_doc[0]
     self.assertIsInstance(first_span, dict)
     self.assertIsInstance(first_span['confidence'], dict)

     self.model.save(self.save_file)
     reloaded = SequenceLabeler.load(self.save_file)
     reloaded.predict(test_texts)
コード例 #3
0
ファイル: test_sequence.py プロジェクト: seeker1943/finetune
    def test_cached_predict(self):
        """
        Check ``cached_predict()`` in three ways:

        - Predictions made inside the context match uncached predictions.
        - The context handles different batch sizes.
        - A repeated cached call takes less than half the time of the first
          call inside the context.
        """
        raw_docs = ["".join(text) for text in self.texts]
        texts, annotations = finetune_to_indico_sequence(raw_docs, self.texts, self.labels,
                                                         none_value=self.model.config.pad_token)
        train_texts, test_texts, train_annotations, _ = train_test_split(texts, annotations, test_size=0.1)
        self.model.fit(train_texts, train_annotations)

        # Enable chunking with a short max_length (presumably so the test docs
        # are split into several chunks) before predicting.
        self.model.config.chunk_long_sequences = True
        self.model.config.max_length = 128

        uncached_preds = self.model.predict(test_texts[:1])

        with self.model.cached_predict():
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump with system clock adjustments and make the timing check flaky.
            start = time.perf_counter()
            self.model.predict(test_texts[:1])
            first = time.perf_counter()
            self.model.predict(test_texts[:1])
            second = time.perf_counter()
            preds = self.model.predict(test_texts[:1])
            self.assertEqual(len(preds), 1)
            preds = self.model.predict(test_texts[:2])
            self.assertEqual(len(preds), 2)

        # uncached_preds holds a single document, so zip stops after one pair:
        # this compares the first document's uncached and cached predictions.
        for uncached_pred, cached_pred in zip(uncached_preds, preds):
            self.assertEqual(str(uncached_pred), str(cached_pred))

        first_prediction_time = first - start
        second_prediction_time = second - first
        self.assertLess(second_prediction_time, first_prediction_time / 2.)
コード例 #4
0
ファイル: test_utils.py プロジェクト: seeker1943/finetune
    def test_overlapping_gpt2_subtokens(self):
        """
        Round-trip two overlapping labels between the finetune format
        (subsequences plus per-subsequence label tuples) and the indico
        annotation format.

        Uses the GPT-2 encoder with ``subtoken_predictions=True``.
        """
        raw = ["Indico Is the best hey"]
        finetunex = [["Indico", " Is the", " best", " hey"]]
        # Every entry is a tuple of labels. The trailing comma in ("<PAD>", )
        # matters: without it the value is a plain string, not a 1-tuple.
        finetuney = [[("1", ), ("1", "2"), ("2", ), ("<PAD>", )]]
        encoder = GPT2Encoder()
        indicox_pred, indicoy_pred = finetune_to_indico_sequence(
            raw,
            finetunex,
            finetuney,
            none_value="<PAD>",
            subtoken_predictions=True)

        indicoy = [[
            {
                'start': 0,
                'end': 13,
                'label': '1',
                'text': 'Indico Is the'
            },
            {
                'start': 6,
                'end': 18,
                'label': '2',
                'text': ' Is the best'
            },
        ]]
        self.assertEqual(indicoy, indicoy_pred)
        self.assertEqual(raw, indicox_pred)

        finetunex_pred, finetuney_pred, *_ = indico_to_finetune_sequence(
            raw, indicoy, encoder=encoder, none_value="<PAD>")
        self.assertEqual(finetunex_pred, finetunex)
        # Only the three labelled subsequences are compared; the trailing
        # "<PAD>" subsequence is not checked.
        for idx in range(3):
            self.assertCountEqual(finetuney[0][idx], finetuney_pred[0][idx])
コード例 #5
0
ファイル: test_utils.py プロジェクト: seeker1943/finetune
    def test_nested_labels(self):
        """
        Round-trip nested labels through finetune_to_indico_sequence and
        indico_to_finetune_sequence.

        Labels "2" and "3" cover "Is the", which lies inside the span of label
        "1" covering the whole text.
        """
        raw = ["Indico Is the best"]
        finetunex = [["Indico ", "Is the", " best"]]
        finetuney = [[("1", ), ("1", "2", "3"), ("1", )]]
        encoder = GPTEncoder()
        indicox_pred, indicoy_pred = finetune_to_indico_sequence(
            raw, finetunex, finetuney, none_value="<PAD>")

        outer_span = {'start': 0, 'end': 18, 'label': '1', 'text': 'Indico Is the best'}
        inner_spans = [
            {'start': 7, 'end': 13, 'label': inner_label, 'text': 'Is the'}
            for inner_label in ('2', '3')
        ]
        indicoy = [[outer_span] + inner_spans]
        self.assertEqual(indicoy, indicoy_pred)
        self.assertEqual(raw, indicox_pred)

        finetunex_pred, finetuney_pred, *_ = indico_to_finetune_sequence(
            raw, indicoy, encoder=encoder, none_value="<PAD>")
        self.assertEqual(finetunex_pred, finetunex)
        for idx in range(3):
            self.assertCountEqual(finetuney[0][idx], finetuney_pred[0][idx])
コード例 #6
0
ファイル: test_utils.py プロジェクト: yishuihanhan/finetune
    def test_three_overlapping_labels(self):
        """
        Round-trip three partially overlapping labels through the finetune and
        indico sequence formats.

        Predicted and expected spans are both sorted by label before they are
        compared, so span order does not matter.
        """
        def by_label(annotation):
            return annotation['label']

        raw = ["Indico Is the very best"]
        finetunex = [["Indico ", "Is the very", " best"]]
        finetuney = [[("<PAD>", ), ("1", "2", "3"), ("1", "3")]]
        encoder = GPTEncoder()
        indicox_pred, indicoy_pred = finetune_to_indico_sequence(raw, finetunex, finetuney, none_value="<PAD>")
        indicoy_pred = [sorted(seq, key=by_label) for seq in indicoy_pred]

        expected_spans = [
            {'start': 7, 'end': 18, 'label': '2', 'text': 'Is the very'},
            {'start': 7, 'end': 23, 'label': '1', 'text': 'Is the very best'},
            {'start': 7, 'end': 23, 'label': '3', 'text': 'Is the very best'},
        ]
        indicoy = [sorted(expected_spans, key=by_label)]
        self.assertEqual(indicoy, indicoy_pred)
        self.assertEqual(raw, indicox_pred)

        finetunex_pred, finetuney_pred, *_ = indico_to_finetune_sequence(
            raw, indicoy, encoder=encoder, none_value="<PAD>"
        )
        self.assertEqual(finetunex_pred, finetunex)
        for idx in range(3):
            self.assertCountEqual(finetuney[0][idx], finetuney_pred[0][idx])
コード例 #7
0
ファイル: test_utils.py プロジェクト: yishuihanhan/finetune
    def test_whitespace_handling(self):
        """
        Check that finetune_to_indico_sequence yields the expected character
        offsets wherever whitespace falls between subsequences.

        Covered placements: embedded newlines, leading spaces, trailing spaces,
        and a mixture of the two.
        """
        newline_text = ["Train:\n\n\n and test tokenization must be equivalent"]
        newline_spans = [
            [
                {'start': 0, 'end': 18, 'label': "1", 'text': "Train:\n\n\n and test"},
                {'start': 10, 'end': 39, 'label': "2", 'text': "and test tokenization must be"}
            ]
        ]
        spaced_text = ["Train and test tokenization must be equivalent"]
        spaced_spans = [
            [
                {'start': 0, 'end': 14, 'label': "1", 'text': "Train and test"},
                {'start': 6, 'end': 35, 'label': "2", 'text': "and test tokenization must be"}
            ]
        ]

        cases = [
            # Newline complications
            (["Train:", "\n\n\n and test", " tokenization must be", " equivalent"],
             newline_text, newline_spans),
            # Spaces before labels
            (["Train", " and test", " tokenization must be", " equivalent"],
             spaced_text, spaced_spans),
            # Spaces after labels
            (["Train ", "and test ", "tokenization must be ", "equivalent"],
             spaced_text, spaced_spans),
            # Whitespace anarchy
            (["Train", " and test ", "tokenization must be", " equivalent"],
             spaced_text, spaced_spans),
        ]

        for subseqs, expected_text, expected_spans in cases:
            # A fresh label list per call, as the original test built one per case.
            subseq_labels = [[("1",), ("1", "2"), ("2",), ("<PAD>",)]]
            indicox_pred, indicoy_pred = finetune_to_indico_sequence(
                expected_text, [subseqs], subseq_labels,
                none_value="<PAD>", subtoken_predictions=False)
            self.assertEqual(indicox_pred, expected_text)
            self.assertEqual(indicoy_pred, expected_spans)
コード例 #8
0
    def test_fit_predict(self):
        """
        Train the default model and a copy that gives "Named Entity" a class
        weight of 100, then check:

        - Prediction and probability output formats.
        - Per-label metric keys.
        - The default model can be saved.
        - The reweighted model reaches a higher "Named Entity" token recall.
        """
        joined_docs = ["".join(chunks) for chunks in self.texts]
        texts, annotations = finetune_to_indico_sequence(
            joined_docs, self.texts, self.labels, none_value=self.model.config.pad_token
        )
        split = train_test_split(texts, annotations, test_size=0.1)
        train_texts, test_texts, train_annotations, test_annotations = split

        weighted_config = self.default_config(class_weights={"Named Entity": 100.0})
        reweighted_model = SequenceLabeler(**weighted_config)
        reweighted_model.fit(train_texts, train_annotations)
        reweighted_token_recall = sequence_labeling_token_recall(
            test_annotations, reweighted_model.predict(test_texts)
        )

        self.model.fit(train_texts, train_annotations)
        predictions = self.model.predict(test_texts)
        probas = self.model.predict_proba(test_texts)

        self.assertIsInstance(probas, list)
        first_doc = probas[0]
        self.assertIsInstance(first_doc, list)
        first_span = first_doc[0]
        self.assertIsInstance(first_span, dict)
        self.assertIsInstance(first_span["confidence"], dict)

        # All four metrics are computed before any membership check runs.
        token_precision, token_recall, overlap_precision, overlap_recall = (
            metric(test_annotations, predictions)
            for metric in (
                sequence_labeling_token_precision,
                sequence_labeling_token_recall,
                sequence_labeling_overlap_precision,
                sequence_labeling_overlap_recall,
            )
        )
        for scores in (token_precision, token_recall, overlap_precision, overlap_recall):
            self.assertIn("Named Entity", scores)

        self.model.save(self.save_file)

        self.assertGreater(
            reweighted_token_recall["Named Entity"], token_recall["Named Entity"]
        )
コード例 #9
0
 def test_cached_predict(self):
     """
     Fit the model, then call ``predict`` twice on the same inputs inside
     ``cached_predict()`` to check that repeated cached prediction succeeds.
     """
     joined_docs = ["".join(chunks) for chunks in self.texts]
     texts, annotations = finetune_to_indico_sequence(
         joined_docs, self.texts, self.labels, none_value=self.model.config.pad_token
     )
     split = train_test_split(texts, annotations, test_size=0.1)
     train_texts, test_texts, train_annotations, _unused_test_annotations = split
     self.model.fit(train_texts, train_annotations)
     with self.model.cached_predict():
         for _attempt in range(2):
             self.model.predict(test_texts)
コード例 #10
0
    def download(self):
        """
        Download ``reuters.xml`` from the dice-group n3-collection and convert
        it to a CSV at ``DATA_PATH``.

        The CSV has a ``texts`` column and an ``annotations`` column holding
        JSON-encoded indico-style annotations.

        Within each ``<document>``, every tag child of ``textwithnamedentities``
        becomes one subsequence. ``namedentityintext`` children are labelled
        "Named Entity"; any other tag child is labelled "<PAD>". Non-tag
        children (bare strings) are skipped.

        The XML is written to ``XML_PATH`` as an intermediate file and is
        removed once it has been parsed, including when parsing fails.

        :raises requests.HTTPError: if the download returns an HTTP error status.
        """
        url = "https://raw.githubusercontent.com/dice-group/n3-collection/master/reuters.xml"
        r = requests.get(url)
        # Fail loudly on HTTP errors instead of parsing an error page into an
        # empty dataset.
        r.raise_for_status()

        with open(XML_PATH, 'wb') as fd:
            fd.write(r.content)

        try:
            # NOTE(review): read in the platform default text encoding, as
            # before; confirm the source XML encoding if this runs on Windows.
            with open(XML_PATH) as fd:
                soup = bs(fd, "html.parser")
        finally:
            # The XML is only an intermediate; never leave it behind.
            os.remove(XML_PATH)

        docs = []
        docs_labels = []
        for elem in soup.find_all("document"):
            texts = []
            labels = []

            # Loop through each child of the element under "textwithnamedentities"
            for c in elem.find("textwithnamedentities").children:
                if isinstance(c, Tag):
                    if c.name == "namedentityintext":
                        label = "Named Entity"  # part of a named entity
                    else:
                        label = "<PAD>"  # irrelevant word
                    texts.append(c.text)
                    labels.append(label)

            docs.append(texts)
            docs_labels.append(labels)

        raw_texts = ["".join(doc) for doc in docs]
        texts, annotations = finetune_to_indico_sequence(
            raw_texts,
            docs,
            docs_labels,
            none_value="<PAD>",
            subtoken_predictions=True)
        df = pd.DataFrame({
            'texts':
            texts,
            'annotations':
            [json.dumps(annotation) for annotation in annotations]
        })
        df.to_csv(DATA_PATH)
コード例 #11
0
    def predict(self, X, per_token=False):
        """
        Produces a list of most likely class labels as determined by the fine-tuned model.

        Chunks from ``self.process_long_sequence`` are stitched back into
        per-document predictions. Consecutive tokens with the same label are
        merged into one substring of the raw text, and the per-token
        probability vectors of that substring are averaged.

        :param X: A list / array of text, shape [batch]. When
            ``config.use_auxiliary_info`` is set, ``X[0]`` holds the texts. The
            full ``X`` (deep-copied) is what goes to ``process_long_sequence``.
        :param per_token: If True, return raw probabilities and labels on a per
            token basis. Every token then starts its own subsequence instead of
            being merged with neighbours that share its label.
        :returns: By default, the per-document annotation lists produced by
            ``finetune_to_indico_sequence``. With ``per_token=True``, one dict per
            document with keys ``"tokens"`` (from ``_spacy_token_predictions``)
            and ``"prediction"`` (that document's annotation list).
        """
        if self.config.use_auxiliary_info:
            X_with_context = copy.deepcopy(X)
            X = X[0]
        else:
            X_with_context = X
        all_subseqs = []
        all_labels = []
        all_probs = []
        all_positions = []
        # presumably 2 positions per chunk are reserved for start/end tokens — TODO confirm
        chunk_size = self.config.max_length - 2
        # Each chunk is viewed as three windows of step_size tokens; only some
        # windows of each chunk are kept (see the start/end selection below).
        step_size = chunk_size // 3
        doc_idx = -1
        for (
                position_seq,
                start_of_doc,
                end_of_doc,
                label_seq,
                proba_seq,
        ) in self.process_long_sequence(X_with_context):
            start, end = 0, None
            if start_of_doc:
                # if this is the first chunk in a document, start accumulating from scratch
                doc_subseqs = []
                doc_labels = []
                doc_probs = []
                doc_positions = []
                doc_starts = []

                doc_idx += 1
                start_of_token = 0
                if not end_of_doc:
                    # first of several chunks: keep the first two windows
                    end = step_size * 2
            else:
                if end_of_doc:
                    # predict on the rest of sequence
                    start = step_size
                else:
                    # predict only on middle third
                    start, end = step_size, step_size * 2

            label_seq = label_seq[start:end]
            position_seq = position_seq[start:end]
            proba_seq = proba_seq[start:end]

            for label, position, proba in zip(label_seq, position_seq,
                                              proba_seq):
                if position == -1:
                    # indicates padding / special tokens
                    continue

                # if there is no current subsequence
                # or the current subsequence has the wrong label
                # (per_token forces a new subsequence for every token)
                if not doc_subseqs or label != doc_labels[-1] or per_token:
                    # start new subsequence; `position` is used as the end
                    # character offset of this token in X[doc_idx]
                    doc_subseqs.append(X[doc_idx][start_of_token:position])
                    doc_labels.append(label)
                    doc_probs.append([proba])
                    doc_positions.append((start_of_token, position))
                    doc_starts.append(start_of_token)
                else:
                    # continue appending to current subsequence: re-slice from the
                    # subsequence's first character up to this token's end
                    doc_subseqs[-1] = X[doc_idx][doc_starts[-1]:position]
                    doc_probs[-1].append(proba)
                start_of_token = position

            if end_of_doc:
                # last chunk in a document
                prob_dicts = []
                for prob_seq in doc_probs:
                    # format probabilities as dictionary: mean over the
                    # subsequence's tokens, keyed by label-encoder class
                    probs = np.mean(np.vstack(prob_seq), axis=0)
                    prob_dicts.append(
                        dict(
                            zip(self.input_pipeline.label_encoder.classes_,
                                probs)))
                    if self.multi_label:
                        # multi-label output omits the pad class from confidences
                        del prob_dicts[-1][self.config.pad_token]

                all_subseqs.append(doc_subseqs)
                all_labels.append(doc_labels)
                all_probs.append(prob_dicts)
                all_positions.append(doc_positions)

        _, doc_annotations = finetune_to_indico_sequence(
            raw_texts=X,
            subseqs=all_subseqs,
            labels=all_labels,
            probs=all_probs,
            none_value=self.config.pad_token,
            subtoken_predictions=self.config.subtoken_predictions,
        )

        if per_token:
            return [{
                "tokens":
                _spacy_token_predictions(
                    raw_text=raw_text,
                    tokens=tokens,
                    probas=probas,
                    positions=positions,
                ),
                "prediction":
                predictions,
            } for raw_text, tokens, labels, probas, positions, predictions in
                    zip(
                        X,
                        all_subseqs,
                        all_labels,
                        all_probs,
                        all_positions,
                        doc_annotations,
                    )]
        else:
            return doc_annotations
コード例 #12
0
    def predict(self, X, per_token=False):
        """
        Produces a list of most likely class labels as determined by the fine-tuned model.

        Each text is encoded into one or more chunks with
        ``input_pipeline._text_to_ids``, and ``_inference`` yields one
        prediction per chunk. The chunks are then stitched back into
        per-document predictions. Consecutive tokens with the same label are
        merged into one substring of the raw text, and their probability
        vectors are averaged.

        :param X: A list / array of text, shape [batch]
        :param per_token: If True, return raw probabilities and labels on a per
            token basis. Every token then starts its own subsequence.
        :returns: By default, the per-document annotation lists produced by
            ``finetune_to_indico_sequence``. With ``per_token=True``, one dict per
            document with keys ``'tokens'`` (from ``_spacy_token_predictions``)
            and ``'prediction'`` (that document's annotation list).
        """
        # presumably 2 positions per chunk are reserved for start/end tokens — TODO confirm
        chunk_size = self.config.max_length - 2
        step_size = chunk_size // 3
        # Flattened list of encoded chunks for all documents, in document order.
        arr_encoded = list(
            itertools.chain.from_iterable(
                self.input_pipeline._text_to_ids([x]) for x in X))
        labels, batch_probas = [], []
        for pred in self._inference(
                X,
                predict_keys=[PredictMode.PROBAS, PredictMode.NORMAL],
                n_examples=len(arr_encoded)):
            labels.append(
                self.input_pipeline.label_encoder.inverse_transform(
                    pred[PredictMode.NORMAL]))
            batch_probas.append(pred[PredictMode.PROBAS])

        all_subseqs = []
        all_labels = []
        all_probs = []
        all_positions = []

        doc_idx = -1
        for chunk_idx, (label_seq,
                        proba_seq) in enumerate(zip(labels, batch_probas)):
            position_seq = arr_encoded[chunk_idx].char_locs
            # A chunk starts a document when its first token id is the encoder's
            # start token; it ends one when it is the last chunk overall or the
            # next chunk starts a document.
            start_of_doc = arr_encoded[chunk_idx].token_ids[0][
                0] == self.input_pipeline.text_encoder.start
            end_of_doc = (chunk_idx + 1 >= len(arr_encoded)
                          or arr_encoded[chunk_idx + 1].token_ids[0][0]
                          == self.input_pipeline.text_encoder.start)
            """
            Chunk idx for prediction.  Dividers at `step_size` increments.
            [  1  |  1  |  2  |  3  |  3  ]
            """
            # (the string above is a diagram: window 1 = first chunk's kept
            # region, 2 = middle chunks, 3 = last chunk)
            start, end = 0, None
            if start_of_doc:
                # if this is the first chunk in a document, start accumulating from scratch
                doc_subseqs = []
                doc_labels = []
                doc_probs = []
                doc_positions = []
                doc_starts = []

                doc_idx += 1
                start_of_token = 0
                if not end_of_doc:
                    # first of several chunks: keep the first two windows
                    end = step_size * 2
            else:
                if end_of_doc:
                    # predict on the rest of sequence
                    start = step_size
                else:
                    # predict only on middle third
                    start, end = step_size, step_size * 2

            label_seq = label_seq[start:end]
            position_seq = position_seq[start:end]
            proba_seq = proba_seq[start:end]

            for label, position, proba in zip(label_seq, position_seq,
                                              proba_seq):
                if position == -1:
                    # indicates padding / special tokens
                    continue

                # if there is no current subsequence
                # or the current subsequence has the wrong label
                # (per_token forces a new subsequence for every token)
                if not doc_subseqs or label != doc_labels[-1] or per_token:
                    # start new subsequence; `position` (from char_locs) is used
                    # as the end character offset of this token in X[doc_idx]
                    doc_subseqs.append(X[doc_idx][start_of_token:position])
                    doc_labels.append(label)
                    doc_probs.append([proba])
                    doc_positions.append((start_of_token, position))
                    doc_starts.append(start_of_token)
                else:
                    # continue appending to current subsequence: re-slice from the
                    # subsequence's first character up to this token's end
                    doc_subseqs[-1] = X[doc_idx][doc_starts[-1]:position]
                    doc_probs[-1].append(proba)
                start_of_token = position

            if end_of_doc:
                # last chunk in a document
                prob_dicts = []
                for prob_seq in doc_probs:
                    # format probabilities as dictionary: mean over the
                    # subsequence's tokens, keyed by label-encoder class
                    probs = np.mean(np.vstack(prob_seq), axis=0)
                    prob_dicts.append(
                        dict(
                            zip(self.input_pipeline.label_encoder.classes_,
                                probs)))
                    if self.multi_label:
                        # multi-label output omits the pad class from confidences
                        del prob_dicts[-1][self.config.pad_token]

                all_subseqs.append(doc_subseqs)
                all_labels.append(doc_labels)
                all_probs.append(prob_dicts)
                all_positions.append(doc_positions)

        _, doc_annotations = finetune_to_indico_sequence(
            raw_texts=X,
            subseqs=all_subseqs,
            labels=all_labels,
            probs=all_probs,
            none_value=self.config.pad_token,
            subtoken_predictions=self.config.subtoken_predictions)

        if per_token:
            return [{
                'tokens':
                _spacy_token_predictions(raw_text=raw_text,
                                         tokens=tokens,
                                         probas=probas,
                                         positions=positions),
                'prediction':
                predictions,
            } for raw_text, tokens, labels, probas, positions, predictions in
                    zip(X, all_subseqs, all_labels, all_probs, all_positions,
                        doc_annotations)]
        else:
            return doc_annotations
コード例 #13
0
ファイル: association.py プロジェクト: yishuihanhan/finetune
    def predict(self, X):
        """
        Produces a list of most likely class labels as determined by the fine-tuned model.

        For each encoded chunk, this collects:

        - Sequence labels and their probabilities.
        - For every token, the single most likely (target token, association
          class) pair and its probability, read from the association
          probabilities after ``prune_probs``.

        Consecutive same-label tokens are merged into subsequences. Token-level
        associations are remapped to subsequence indices, and everything is
        passed to ``finetune_to_indico_sequence`` via ``associations=``.

        :param X: A list / array of text, shape [batch]
        :returns: list of class labels (the per-document annotation lists
            returned by ``finetune_to_indico_sequence``).
        """
        pad_token = [self.config.pad_token
                     ] if self.multi_label else self.config.pad_token
        if self.config.viable_edges is None:
            LOGGER.warning(
                "config.viable_edges is not set, this is probably incorrect.")

        #TODO(Ben) combine this into the sequence labeling model??

        chunk_size = self.config.max_length - 2
        step_size = chunk_size // 3
        # NOTE(review): the 4-tuple pad_token presumably pads the sequence
        # label, another label field, and two association fields — confirm
        # against _text_to_ids.
        arr_encoded = list(
            itertools.chain.from_iterable(
                self.input_pipeline._text_to_ids(
                    [x], pad_token=(pad_token, pad_token, -1, -2)) for x in X))
        # Number of encoded tokens per chunk; predictions past this are padding.
        lens = [len(a.char_locs) for a in arr_encoded]
        labels, batch_probas, associations = [], [], []
        predict_keys = [
            PredictMode.SEQUENCE, PredictMode.SEQUENCE_PROBAS,
            PredictMode.ASSOCIATION, PredictMode.ASSOCIATION_PROBAS
        ]
        for l, pred in zip(lens, self._inference(X,
                                                 predict_keys=predict_keys)):
            pred_labels = self.input_pipeline.label_encoder.inverse_transform(
                pred[PredictMode.SEQUENCE])
            # Force the pad label on positions beyond the chunk's encoded length.
            pred_labels = [
                label if i < l else self.config.pad_token
                for i, label in enumerate(pred_labels)
            ]
            labels.append(pred_labels)
            batch_probas.append(pred[PredictMode.SEQUENCE_PROBAS])
            pred["association_probs"] = self.prune_probs(
                pred[PredictMode.ASSOCIATION_PROBAS], pred_labels)
            # For each token, the (target index, class id) of the largest entry.
            # NOTE(review): argmax is taken over the UNPRUNED association probs,
            # while the probability reported below is read from the pruned
            # array; if pruning zeroes a non-viable edge, a pruned-out edge can
            # still be selected. Confirm whether the argmax should use
            # pred["association_probs"].
            most_likely_associations, most_likely_class_id = zip(*[
                np.unravel_index(np.argmax(a, axis=None), a.shape)
                for a in pred[PredictMode.ASSOCIATION_PROBAS]
            ])
            associations.append(
                (most_likely_associations,
                 self.input_pipeline.association_encoder.inverse_transform(
                     most_likely_class_id), [
                         prob[idx, cls] for prob, idx, cls in zip(
                             pred["association_probs"],
                             most_likely_associations, most_likely_class_id)
                     ]))
        all_subseqs = []
        all_labels = []
        all_probs = []
        all_assocs = []

        doc_idx = -1
        for chunk_idx, (label_seq, proba_seq, association) in enumerate(
                zip(labels, batch_probas, associations)):
            association_idx, association_class, association_prob = association

            position_seq = arr_encoded[chunk_idx].char_locs
            # Document boundaries are detected from the encoder start token.
            start_of_doc = arr_encoded[chunk_idx].token_ids[0][
                0] == self.input_pipeline.text_encoder.start
            end_of_doc = (chunk_idx + 1 >= len(arr_encoded)
                          or arr_encoded[chunk_idx + 1].token_ids[0][0]
                          == self.input_pipeline.text_encoder.start)
            start, end = 0, None
            if start_of_doc:
                # if this is the first chunk in a document, start accumulating from scratch
                doc_subseqs = []
                doc_labels = []
                doc_probs = []
                doc_assocs = []
                doc_idx += 1
                start_of_token = 0
                if not end_of_doc:
                    end = step_size * 2
            else:
                if end_of_doc:
                    # predict on the rest of sequence
                    start = step_size
                else:
                    # predict only on middle third
                    start, end = step_size, step_size * 2

            label_seq = label_seq[start:end]
            position_seq = position_seq[start:end]
            proba_seq = proba_seq[start:end]
            # NOTE(review): this mapping is rebuilt for every chunk, yet
            # doc_assocs accumulates over all chunks of a document and is only
            # resolved at end_of_doc with the LAST chunk's mapping. Multi-chunk
            # documents may therefore drop or mis-map associations.
            tok_idx_to_subseq = dict()
            # NOTE(review): tok_idx indexes the SLICED sequences (offset by
            # `start`), but association_idx/class/prob cover the whole chunk;
            # for chunks with start == step_size these indices disagree.
            for tok_idx, (label, position, proba) in enumerate(
                    zip(label_seq, position_seq, proba_seq)):
                if position == -1:
                    # indicates padding / special tokens
                    continue

                # if there is no current subsequence
                # or the current subsequence has the wrong label
                if not doc_subseqs or label != doc_labels[-1]:
                    # start new subsequence
                    doc_subseqs.append(X[doc_idx][start_of_token:position])
                    doc_probs.append([proba])
                    doc_assocs.append([
                        (tok_idx, association_idx[tok_idx],
                         association_class[tok_idx], association_prob[tok_idx])
                    ])
                    doc_labels.append(label)
                else:
                    # continue appending to current subsequence; pieces are
                    # contiguous because start_of_token is the previous end
                    doc_subseqs[-1] += X[doc_idx][start_of_token:position]
                    doc_probs[-1].append(proba)
                    doc_assocs[-1].append((tok_idx, association_idx[tok_idx],
                                           association_class[tok_idx],
                                           association_prob[tok_idx]))
                # Remember which subsequence this token was merged into.
                tok_idx_to_subseq[tok_idx] = len(doc_labels) - 1
                start_of_token = position

            if end_of_doc:
                # last chunk in a document
                prob_dicts = []
                for prob_seq in doc_probs:
                    # format probabilities as dictionary
                    probs = np.mean(np.vstack(prob_seq), axis=0)
                    prob_dicts.append(
                        dict(
                            zip(self.input_pipeline.label_encoder.classes_,
                                probs)))
                    if self.multi_label:
                        del prob_dicts[-1][self.config.pad_token]

                # Translate token-level (from, to) associations into
                # subsequence-level ones, dropping any whose endpoints were not
                # mapped to a subsequence.
                doc_assocs_by_idx = []
                for assoc in doc_assocs:
                    doc_assocs_by_idx.append([])
                    for from_idx, to_idx, cls, prob in assoc:
                        if from_idx in tok_idx_to_subseq and to_idx in tok_idx_to_subseq:
                            doc_assocs_by_idx[-1].append(
                                (tok_idx_to_subseq[from_idx],
                                 tok_idx_to_subseq[to_idx], cls, prob))

                all_subseqs.append(doc_subseqs)
                all_labels.append(doc_labels)
                all_probs.append(prob_dicts)
                all_assocs.append(doc_assocs_by_idx)

        _, doc_annotations = finetune_to_indico_sequence(
            raw_texts=X,
            subseqs=all_subseqs,
            labels=all_labels,
            probs=all_probs,
            none_value=self.config.pad_token,
            subtoken_predictions=self.config.subtoken_predictions,
            associations=all_assocs)
        return doc_annotations