def test_should_not_fail_on_shuffle_dataset(self, preprocessor,
                                             embeddings):
     data_generator = DataGenerator(np.asarray([SENTENCE_TOKENS_1]),
                                    np.asarray([[LABEL_1]]),
                                    features=np.asarray(
                                        [SENTENCE_FEATURES_1]),
                                    preprocessor=preprocessor,
                                    embeddings=embeddings,
                                    **DEFAULT_ARGS)
     data_generator._shuffle_dataset()  # pylint: disable=protected-access
 def test_should_return_left_pad_batch_values(self, preprocessor,
                                              embeddings):
     """Check the reported lengths of a batch mixing a long and a short sentence.

     With features enabled, ``get_lengths`` is expected to report the long
     sentence's length for every entry of both the token input and the
     feature input of the first batch.
     """
     preprocessor.return_features = True
     batch_size = 2
     long_length = len(LONG_SENTENCE_TOKENS)
     short_length = len(SHORT_SENTENCE_TOKENS)
     tokens = np.asarray([LONG_SENTENCE_TOKENS, SHORT_SENTENCE_TOKENS])
     labels = np.asarray([[LABEL_1] * long_length,
                          [LABEL_2] * short_length])
     features = np.asarray([
         np.asarray([WORD_FEATURES_1] * long_length),
         np.asarray([WORD_FEATURES_2] * short_length)
     ])
     generator_kwargs = {**DEFAULT_ARGS, 'batch_size': batch_size}
     data_generator = DataGenerator(
         tokens,
         labels,
         features=features,
         preprocessor=preprocessor,
         embeddings=embeddings,
         **generator_kwargs)
     batch = data_generator[0]
     LOGGER.debug('batch: %s', batch)
     assert len(batch) == 2
     inputs, _ = batch
     token_input = inputs[0]
     feature_input = inputs[2]
     expected_lengths = [long_length] * batch_size
     assert get_lengths(token_input) == expected_lengths
     assert get_lengths(feature_input) == expected_lengths
 def test_should_truncate_using_max_sequence_length_if_already_tokenized(
         self, preprocessor, embeddings):
     """Check truncation to ``max_sequence_length`` for pre-tokenized input.

     The generator is created with ``tokenize=False`` and a maximum
     sequence length equal to the short sentence's length; tokens, features
     and labels of the first batch are all expected to have that length.
     """
     preprocessor.return_features = True
     batch_size = 2
     LOGGER.debug('SHORT_SENTENCE_TOKENS: %s', SHORT_SENTENCE_TOKENS)
     LOGGER.debug('LONG_SENTENCE_TOKENS: %s', LONG_SENTENCE_TOKENS)
     long_length = len(LONG_SENTENCE_TOKENS)
     short_length = len(SHORT_SENTENCE_TOKENS)
     tokens = np.asarray([LONG_SENTENCE_TOKENS, SHORT_SENTENCE_TOKENS])
     labels = np.asarray([[LABEL_1] * long_length,
                          [LABEL_2] * short_length])
     features = np.asarray([
         np.asarray([WORD_FEATURES_1] * long_length),
         np.asarray([WORD_FEATURES_2] * short_length)
     ])
     generator_kwargs = {
         **DEFAULT_ARGS,
         'batch_size': batch_size,
         'tokenize': False
     }
     data_generator = DataGenerator(
         tokens,
         labels,
         features=features,
         preprocessor=preprocessor,
         embeddings=embeddings,
         max_sequence_length=short_length,
         **generator_kwargs)
     batch = data_generator[0]
     LOGGER.debug('batch: %s', batch)
     assert len(batch) == 2
     inputs, batch_y = batch
     token_input = inputs[0]
     feature_input = inputs[2]
     expected_lengths = [short_length] * batch_size
     assert get_lengths(token_input) == expected_lengths
     assert get_lengths(feature_input) == expected_lengths
     assert get_lengths(batch_y) == expected_lengths
 def test_should_concatenate_word_embeddings_when_using_text_feature(
         self, preprocessor, embeddings):
     """Check the inputs produced when feature column 0 is used as the text.

     Each token's text feature holds two space-separated words. With
     ``concatenated_embeddings_token_count=2`` the word-vector input is
     expected to equal the vectors of both words concatenated along the
     last axis, and the char input to be derived from the full text values.
     """
     preprocessor.return_casing = False
     tokens = [WORD_1, WORD_2]
     first_text = ' '.join([WORD_3, WORD_4])
     second_text = ' '.join([WORD_4, PAD])
     text_features = [[[first_text], [second_text]]]
     expected_char_indices = get_words_char_indices(
         [first_text, second_text])
     # by default only using the first word token for embeddings
     first_word_vectors = get_word_vectors([WORD_3, WORD_4])
     second_word_vectors = get_word_vectors([WORD_4, PAD])
     expected_word_vectors = np.concatenate(
         [first_word_vectors, second_word_vectors], axis=-1)
     data_generator = DataGenerator(
         np.asarray([tokens]),
         np.asarray([[LABEL_1]]),
         features=np.asarray(text_features),
         preprocessor=preprocessor,
         embeddings=embeddings,
         text_feature_indices=[0],
         concatenated_embeddings_token_count=2,
         **DEFAULT_ARGS)
     item = data_generator[0]
     LOGGER.debug('item: %s', item)
     assert len(item) == 2
     inputs, labels = item
     assert all_close(labels, get_label_indices([LABEL_1]))
     assert all_close(inputs[0], expected_word_vectors)
     LOGGER.debug('expected char_indices_1: %s', expected_char_indices)
     LOGGER.debug('x[1]: %s', inputs[1])
     assert all_close(inputs[1], expected_char_indices)
     assert all_close(inputs[-1], [len(tokens)])
 def test_should_concatenate_word_embeddings_if_using_multiple_tokens(
         self, preprocessor, embeddings):
     """Check the inputs when feature column 0 is an additional token.

     The word-vector and char-index inputs are each expected to equal the
     sentence-token values concatenated (last axis) with the values of the
     additional feature tokens.
     """
     preprocessor.return_casing = False
     tokens = [WORD_1, WORD_2]
     extra_tokens = [WORD_3, WORD_4]
     data_generator = DataGenerator(
         np.asarray([tokens]),
         np.asarray([[LABEL_1]]),
         features=np.asarray([[[WORD_3], [WORD_4]]]),
         additional_token_feature_indices=[0],
         preprocessor=preprocessor,
         embeddings=embeddings,
         **DEFAULT_ARGS)
     item = data_generator[0]
     LOGGER.debug('item: %s', item)
     assert len(item) == 2
     inputs, labels = item
     assert all_close(labels, get_label_indices([LABEL_1]))

     expected_word_vectors = np.concatenate(
         (get_word_vectors(tokens), get_word_vectors(extra_tokens)),
         axis=-1)
     assert all_close(inputs[0], expected_word_vectors)

     expected_char_indices = np.concatenate(
         ([get_words_char_indices(tokens)],
          [get_words_char_indices(extra_tokens)]),
         axis=-1)
     assert all_close(inputs[1], expected_char_indices)

     assert all_close(inputs[-1], [len(tokens)])
 def test_should_be_able_to_get_item(self, preprocessor, embeddings):
     """Check the first item of a one-sentence generator without features.

     The item is an ``(inputs, labels)`` pair; the first, second and last
     inputs are compared with the expected word vectors, char indices and
     sequence length respectively.
     """
     data_generator = DataGenerator(
         np.asarray([SENTENCE_TOKENS_1]),
         np.asarray([[LABEL_1]]),
         preprocessor=preprocessor,
         embeddings=embeddings,
         **DEFAULT_ARGS)
     item = data_generator[0]
     LOGGER.debug('item: %s', item)
     assert len(item) == 2
     inputs, labels = item
     expected_labels = get_label_indices([LABEL_1])
     assert all_close(labels, expected_labels)
     expected_word_vectors = get_word_vectors(SENTENCE_TOKENS_1)
     assert all_close(inputs[0], expected_word_vectors)
     expected_char_indices = [get_words_char_indices(SENTENCE_TOKENS_1)]
     assert all_close(inputs[1], expected_char_indices)
     expected_lengths = [len(SENTENCE_TOKENS_1)]
     assert all_close(inputs[-1], expected_lengths)
 def test_should_return_features(self, preprocessor, embeddings):
     """Check that transformed features appear as the third input.

     With ``preprocessor.return_features`` enabled, the inputs are compared
     in order with the expected word vectors, char indices and transformed
     features; the last input is compared with the sequence length.
     """
     preprocessor.return_features = True
     data_generator = DataGenerator(
         np.asarray([SENTENCE_TOKENS_1]),
         np.asarray([[LABEL_1]]),
         features=np.asarray([SENTENCE_FEATURES_1]),
         preprocessor=preprocessor,
         embeddings=embeddings,
         **DEFAULT_ARGS)
     item = data_generator[0]
     LOGGER.debug('item: %s', item)
     assert len(item) == 2
     inputs, labels = item
     expected_labels = get_label_indices([LABEL_1])
     assert all_close(labels, expected_labels)
     expected_word_vectors = get_word_vectors(SENTENCE_TOKENS_1)
     assert all_close(inputs[0], expected_word_vectors)
     expected_char_indices = [get_words_char_indices(SENTENCE_TOKENS_1)]
     assert all_close(inputs[1], expected_char_indices)
     expected_features = [get_transformed_features(SENTENCE_FEATURES_1)]
     assert all_close(inputs[2], expected_features)
     expected_lengths = [len(SENTENCE_TOKENS_1)]
     assert all_close(inputs[-1], expected_lengths)
 def test_should_use_dummy_word_embeddings_if_disabled(self, preprocessor):
     """Check the word-vector input when word embeddings are switched off.

     With ``use_word_embeddings=False`` and no embeddings object, the first
     input is expected to match a zero-width float32 array with one row per
     token.
     """
     preprocessor.return_casing = False
     data_generator = DataGenerator(
         np.asarray([SENTENCE_TOKENS_1]),
         np.asarray([[LABEL_1]]),
         preprocessor=preprocessor,
         use_word_embeddings=False,
         embeddings=None,
         **DEFAULT_ARGS)
     item = data_generator[0]
     LOGGER.debug('item: %s', item)
     assert len(item) == 2
     inputs, labels = item
     expected_labels = get_label_indices([LABEL_1])
     assert all_close(labels, expected_labels)
     token_count = len(SENTENCE_TOKENS_1)
     empty_word_vectors = np.zeros((token_count, 0), dtype='float32')
     assert all_close(inputs[0], [empty_word_vectors])
     expected_char_indices = [get_words_char_indices(SENTENCE_TOKENS_1)]
     assert all_close(inputs[1], expected_char_indices)
     assert all_close(inputs[-1], [len(SENTENCE_TOKENS_1)])
 def create_data_generator(self, *args, name_suffix: str,
                           **kwargs) -> DataGenerator:
     """Create a ``DataGenerator`` configured from this object's settings.

     Args:
         *args: Positional arguments forwarded unchanged to ``DataGenerator``.
         name_suffix: Appended to the model name (separated by a dot) to
             form the generator's ``name``.
         **kwargs: Extra keyword arguments forwarded to ``DataGenerator``.

     Returns:
         The newly constructed ``DataGenerator``.
     """
     model_config = self.model_config
     training_config = self.training_config
     generator_name = '%s.%s' % (model_config.model_name, name_suffix)
     # Keyword arguments are passed explicitly (not merged into a dict) so
     # that a duplicate key in ``kwargs`` still raises a TypeError.
     return DataGenerator(  # type: ignore
         *args,
         batch_size=training_config.batch_size,
         input_window_stride=training_config.input_window_stride,
         stateful=model_config.stateful,
         preprocessor=self.preprocessor,
         additional_token_feature_indices=(
             model_config.additional_token_feature_indices),
         text_feature_indices=model_config.text_feature_indices,
         concatenated_embeddings_token_count=(
             model_config.concatenated_embeddings_token_count),
         char_embed_size=model_config.char_embedding_size,
         is_deprecated_padded_batch_text_list_enabled=(
             model_config.is_deprecated_padded_batch_text_list_enabled),
         max_sequence_length=model_config.max_sequence_length,
         embeddings=self.embeddings,
         name=generator_name,
         **kwargs)
 def create_eval_data_generator(self, *args, **kwargs) -> DataGenerator:
     """Create a ``DataGenerator`` for evaluation.

     The batch size is ``self.eval_batch_size`` when that is truthy,
     otherwise the training batch size. The maximum sequence length is
     taken from ``self.eval_max_sequence_length``.

     Args:
         *args: Positional arguments forwarded unchanged to ``DataGenerator``.
         **kwargs: Extra keyword arguments forwarded to ``DataGenerator``.

     Returns:
         The newly constructed ``DataGenerator``.
     """
     batch_size = self.eval_batch_size
     if not batch_size:
         # no (or zero) evaluation batch size: reuse the training one
         batch_size = self.training_config.batch_size
     model_config = self.model_config
     # NOTE(review): the preprocessor is read from ``self.p`` here, while
     # ``create_data_generator`` reads ``self.preprocessor``; presumably
     # the two methods live on different classes - confirm.
     return DataGenerator(  # type: ignore
         *args,
         batch_size=batch_size,
         preprocessor=self.p,
         additional_token_feature_indices=(
             model_config.additional_token_feature_indices),
         text_feature_indices=model_config.text_feature_indices,
         concatenated_embeddings_token_count=(
             model_config.concatenated_embeddings_token_count),
         char_embed_size=model_config.char_embedding_size,
         is_deprecated_padded_batch_text_list_enabled=(
             model_config.is_deprecated_padded_batch_text_list_enabled),
         max_sequence_length=self.eval_max_sequence_length,
         embeddings=self.embeddings,
         **kwargs)
    def test_should_generate_windows_and_disable_shuffle(
            self, preprocessor, embeddings):
        """Check windowing of a three-token sentence with stride and length 2.

        Although ``shuffle=True`` is requested, the generator is expected to
        report ``shuffle`` as disabled. Two batches are expected: the first
        covering the first two tokens, the second covering the remaining
        token (with a reported length of 1).
        """
        generator_kwargs = {**DEFAULT_ARGS, 'shuffle': True}
        batches = DataGenerator(
            np.asarray([[WORD_1, WORD_2, WORD_3]]),
            np.asarray([[LABEL_1, LABEL_2, LABEL_3]]),
            preprocessor=preprocessor,
            embeddings=embeddings,
            input_window_stride=2,
            max_sequence_length=2,
            **generator_kwargs)
        assert not batches.shuffle

        LOGGER.debug('batches: %s', batches)
        assert len(batches) == 2

        # first window: WORD_1, WORD_2
        first_batch = batches[0]
        LOGGER.debug('batch_0: %s', first_batch)
        assert len(first_batch) == 2
        inputs, labels = first_batch
        LOGGER.debug('labels: %s', labels)
        first_window_tokens = [WORD_1, WORD_2]
        assert all_close(labels, get_label_indices([LABEL_1, LABEL_2]))
        assert all_close(inputs[0], get_word_vectors(first_window_tokens))
        assert all_close(
            inputs[1], [get_words_char_indices(first_window_tokens)])
        assert all_close(inputs[-1], [2])

        # second window: WORD_3 followed by a placeholder
        second_batch = batches[1]
        LOGGER.debug('batch_1: %s', second_batch)
        # due to extend, the minimum length is 2
        assert len(second_batch) == 2
        inputs, labels = second_batch
        LOGGER.debug('labels: %s', labels)
        second_window_tokens = [WORD_3, None]
        assert all_close(labels, get_label_indices([LABEL_3]))
        assert all_close(inputs[0], get_word_vectors(second_window_tokens))
        assert all_close(
            inputs[1], [get_words_char_indices(second_window_tokens)])
        assert all_close(inputs[-1], [1])
def iter_predict_texts_with_sliding_window_if_enabled(
        texts: List[Union[str, List[str]]],
        model_config: ModelConfig,
        preprocessor: WordPreprocessor,
        max_sequence_length: Optional[int],
        model,
        input_window_stride: Optional[int] = None,
        embeddings: Optional[Embeddings] = None,
        features: Optional[List[List[List[str]]]] = None
) -> Iterable[np.ndarray]:
    """Yield one concatenated prediction array per input text, in input order.

    The texts are wrapped in a ``DataGenerator`` (without labels and without
    shuffling). Each batch of windows is passed to ``model.predict_on_batch``
    and the per-window predictions are collected per text. Where a window
    starts before the end of what has already been collected for its text,
    the overlapping leading rows are dropped before appending.

    A text's predictions are yielded as soon as it is complete *and* it is the
    next text in input order; everything not yet yielded is emitted after the
    last batch.

    Args:
        texts: The inputs. If the first element is a ``str`` the generator is
            asked to tokenize, otherwise the inputs are treated as already
            tokenized.
        model_config: Supplies batch size, feature indices, statefulness and
            the model name used for logging and the generator name.
        preprocessor: Passed through to ``DataGenerator``.
        max_sequence_length: Passed through to ``DataGenerator``; also used
            to log whether any pre-tokenized sequence exceeds it.
        model: Object providing ``predict_on_batch``.
        input_window_stride: Passed through to ``DataGenerator``.
        embeddings: Passed through to ``DataGenerator``.
        features: Passed through to ``DataGenerator``.

    Yields:
        The predictions of one text, concatenated along axis 0.
    """
    if not texts:
        # nothing to predict: the generator finishes without yielding
        LOGGER.info('passed in empty texts, model: %s',
                    model_config.model_name)
        return
    # only the first element is inspected to decide whether to tokenize
    should_tokenize = (
        len(texts) > 0  # pylint: disable=len-as-condition
        and isinstance(texts[0], str))

    # purely informational: report how the sequence lengths relate to the limit
    if not should_tokenize and max_sequence_length:
        max_actual_sequence_length = max(len(text) for text in texts)
        if max_actual_sequence_length <= max_sequence_length:
            LOGGER.info(
                'all text sequences below max sequence length: %d <= %d (model: %s)',
                max_actual_sequence_length, max_sequence_length,
                model_config.model_name)
        elif model_config.stateful:
            LOGGER.info((
                'some text sequences exceed max sequence length (using sliding windows):'
                ' %d > %d (model: %s)'), max_actual_sequence_length,
                        max_sequence_length, model_config.model_name)
        else:
            LOGGER.info(
                ('some text sequences exceed max sequence length'
                 ' (truncate, model is not stateful): %d > %d (model: %s)'),
                max_actual_sequence_length, max_sequence_length,
                model_config.model_name)

    predict_generator = DataGenerator(
        x=texts,
        y=None,
        batch_size=model_config.batch_size,
        preprocessor=preprocessor,
        additional_token_feature_indices=model_config.
        additional_token_feature_indices,
        text_feature_indices=model_config.text_feature_indices,
        concatenated_embeddings_token_count=(
            model_config.concatenated_embeddings_token_count),
        char_embed_size=model_config.char_embedding_size,
        is_deprecated_padded_batch_text_list_enabled=(
            model_config.is_deprecated_padded_batch_text_list_enabled),
        max_sequence_length=max_sequence_length,
        input_window_stride=input_window_stride,
        stateful=model_config.stateful,
        embeddings=embeddings,
        tokenize=should_tokenize,
        shuffle=False,
        features=features,
        name='%s.predict_generator' % model_config.model_name)

    # one list of prediction chunks per input text, filled window by window
    prediction_list_list: List[List[np.ndarray]] = [[] for _ in texts]
    batch_window_indices_and_offsets_iterable = logging_tqdm(
        iter_batch_window_indices_and_offsets(predict_generator),
        logger=LOGGER,
        total=len(predict_generator),
        desc='%s: ' % predict_generator.name,
        unit='batch')
    # index of the next text to yield; texts are only yielded in input order
    completed_curser = 0
    for batch_window_indices_and_offsets in batch_window_indices_and_offsets_iterable:
        LOGGER.debug('predict batch_window_indices_and_offsets: %s',
                     batch_window_indices_and_offsets)
        generator_output = predict_generator.get_window_batch_data(
            batch_window_indices_and_offsets)
        LOGGER.debug('predict on batch: %d',
                     len(batch_window_indices_and_offsets))
        # only the first element of the generator output is fed to the model
        batch_predictions = model.predict_on_batch(generator_output[0])
        LOGGER.debug('preds.shape: %s', batch_predictions.shape)
        for window_indices_and_offsets, seq_predictions in zip(
                batch_window_indices_and_offsets, batch_predictions):
            text_index, text_offset = window_indices_and_offsets
            current_prediction_list = prediction_list_list[text_index]
            LOGGER.debug('prediction_list_list[%d]: %s', text_index,
                         current_prediction_list)
            # number of prediction rows already collected for this text
            current_offset = sum((len(a) for a in current_prediction_list))
            if current_offset > text_offset:
                # skip over the overlapping window
                seq_predictions = seq_predictions[(current_offset -
                                                   text_offset):, :]
                text_offset = current_offset
            # a window starting beyond the collected rows would leave a gap
            assert (current_offset == text_offset
                    ), "expected %d to be %d" % (current_offset, text_offset)
            current_prediction_list.append(seq_predictions)
            next_offset = sum((len(a) for a in current_prediction_list))
            # NOTE(review): for untokenized (str) texts this compares against
            # the string length rather than a token count - confirm intended.
            is_complete = (next_offset >= len(texts[text_index]))
            LOGGER.debug(
                'is_complete: %s, text_index=%d, completed_curser=%d, next_offset=%d, textlen=%d',
                is_complete, text_index, completed_curser, next_offset,
                len(texts[text_index]))
            if (is_complete and text_index == completed_curser):
                yield np.concatenate(current_prediction_list, axis=0)
                completed_curser += 1

    # emit whatever was not yielded during the loop, preserving input order
    # NOTE(review): np.concatenate raises ValueError for an empty list, so a
    # text that never received a prediction window would fail here - confirm
    # whether that can happen.
    for prediction_list in prediction_list_list[completed_curser:]:
        yield np.concatenate(prediction_list, axis=0)
def iter_batch_window_indices_and_offsets(
        data_generator: DataGenerator) -> Iterable[List[Tuple[int, int]]]:
    """Lazily iterate over the window indices and offsets of every batch.

    The number of batches is read from ``len(data_generator)`` when this
    function is called; each batch's result is only fetched from
    ``get_batch_window_indices_and_offsets`` as the iterator is consumed.
    """
    batch_indices = range(len(data_generator))
    return map(data_generator.get_batch_window_indices_and_offsets,
               batch_indices)
 def test_should_be_able_to_instantiate(self, preprocessor, embeddings):
     """Smoke test: constructing a ``DataGenerator`` must not raise."""
     tokens = np.asarray([[WORD_1, WORD_2]])
     labels = np.asarray([[LABEL_1]])
     DataGenerator(
         tokens,
         labels,
         preprocessor=preprocessor,
         embeddings=embeddings,
         **DEFAULT_ARGS)