def test_predict_iter(self):
        """``_predict_iter`` rejects misuse and yields one prediction per input.

        Checks, in order, that ``model._predict_iter``:

        * raises ``AssertionError`` when called before ``model.load()``;
        * raises ``TypeError`` when given a non-iterable (an ``int``);
        * yields one ``dict`` per input for both a list and an iterator,
          including a larger 150-sample input;
        * yields nothing for an iterator that has already been exhausted.
        """
        data = [{
            "text": "The laptop case was great and cover was rubbish"
        }, {
            "text": "Another day at the office"
        }, {
            "text": "The laptop case was great and cover was rubbish"
        }]
        # The model has not been loaded yet (``load`` is only called further
        # down), so predicting must fail with an AssertionError.
        model_dir = self.TARGET_EXTRACTION_MODEL
        model = AllenNLPModel('TE', self.CONFIG_FILE, 'target-tagger',
                              model_dir)
        with pytest.raises(AssertionError):
            for _ in model._predict_iter(data):
                pass
        # Once loaded, data that is neither a list nor an iterable must raise
        # a TypeError.
        model.load()
        non_iter_data = 5
        with pytest.raises(TypeError):
            for _ in model._predict_iter(non_iter_data):
                pass
        # Normal cases: the same samples as a list and as a one-shot iterator.
        for data_type in [data, iter(data)]:
            predictions = list(model._predict_iter(data_type))
            assert 3 == len(predictions)
            assert isinstance(predictions[0], dict)
            # "Another day at the office" has 5 tokens, hence 5 tags, while
            # class_probabilities has 9 rows — presumably padded to the
            # longest (9-token) sentence in the batch; TODO confirm.
            assert 5 == len(predictions[1]['tags'])
            assert 9 == len(predictions[1]['class_probabilities'])

        # A larger input of 150 samples (the 3 samples repeated 50 times).
        larger_dataset = data * 50
        for data_type in [larger_dataset, iter(larger_dataset)]:
            predictions = list(model._predict_iter(data_type))
            assert 150 == len(predictions)
            assert isinstance(predictions[0], dict)
            # predictions[-2] is "Another day at the office" (5 tokens) and
            # predictions[-1] is the 9-token laptop sentence.
            assert 5 == len(predictions[-2]['tags'])
            assert 9 == len(predictions[-2]['class_probabilities'])
            assert 9 == len(predictions[-1]['tags'])
            assert 9 == len(predictions[-1]['class_probabilities'])

        # Feeding no data at all, which can happen when an iterator has
        # already been consumed elsewhere, must yield no predictions.
        alt_data = iter(data)
        # Exhaust alt_data so that nothing is left for _predict_iter to read.
        assert 3 == len(list(alt_data))
        predictions = list(model._predict_iter(alt_data))
        assert not predictions
예제 #2
0
    def test_predict_iter(self, batch_size: Optional[int],
                          yield_original_target: bool):
        """``_predict_iter`` honours ``batch_size`` and ``yield_original_target``.

        Both arguments are forwarded unchanged to every ``_predict_iter``
        call (presumably supplied by a pytest parametrisation outside this
        view). When ``yield_original_target`` is true, each yielded item is
        expected to be a ``(prediction, original_input)`` tuple. Otherwise
        each item is just the prediction ``dict``.

        Checks, in order, that ``model._predict_iter``:

        * raises ``AssertionError`` when called before ``model.load()``;
        * raises ``TypeError`` when given a non-iterable (an ``int``);
        * yields one prediction per input for both a list and an iterator,
          including a larger 150-sample input, and — when requested — pairs
          every prediction with the exact input it was produced from;
        * yields nothing for an iterator that has already been exhausted.
        """
        data = [{
            "text": "The laptop case was great and cover was rubbish"
        }, {
            "text": "Another day at the office"
        }, {
            "text": "The laptop case was great and cover was rubbish"
        }]
        # Forwarded identically to every _predict_iter call below.
        predict_kwargs = {'batch_size': batch_size,
                          'yield_original_target': yield_original_target}

        def _predictions_only(predictions, inputs):
            # With yield_original_target every item must be a
            # (prediction, original_input) tuple whose original input matches
            # the corresponding entry of ``inputs``. Return only the
            # predictions so the remaining checks are the same in both modes.
            if not yield_original_target:
                return predictions
            prediction_dicts = []
            for item, expected_input in zip(predictions, inputs):
                assert isinstance(item, tuple)
                prediction, original_input = item
                assert len(expected_input) == len(original_input)
                for key, value in expected_input.items():
                    assert value == original_input[key]
                prediction_dicts.append(prediction)
            return prediction_dicts

        # The model has not been loaded yet (``load`` is only called further
        # down), so predicting must fail with an AssertionError.
        model_dir = self.TARGET_EXTRACTION_MODEL
        model = AllenNLPModel('TE', self.CONFIG_FILE, 'target-tagger',
                              model_dir)
        with pytest.raises(AssertionError):
            for _ in model._predict_iter(data, **predict_kwargs):
                pass
        # Once loaded, data that is neither a list nor an iterable must raise
        # a TypeError.
        model.load()
        non_iter_data = 5
        with pytest.raises(TypeError):
            for _ in model._predict_iter(non_iter_data, **predict_kwargs):
                pass
        # Normal cases: the same samples as a list and as a one-shot iterator.
        correct_text_1 = "Another day at the office"
        correct_tokens_1 = correct_text_1.split()
        for data_type in [data, iter(data)]:
            predictions = list(
                model._predict_iter(data_type, **predict_kwargs))
            assert 3 == len(predictions)
            predictions = _predictions_only(predictions, data)
            assert isinstance(predictions[0], dict)

            second = predictions[1]
            assert 6 == len(second)
            # 5 tokens, hence 5 tags. class_probabilities has 9 rows,
            # presumably padded to the longest (9-token) sentence in the
            # batch; TODO confirm.
            assert 5 == len(second['tags'])
            assert 9 == len(second['class_probabilities'])
            assert correct_tokens_1 == second['words']
            assert correct_text_1 == second['text']

        # A larger input of 150 samples (the 3 samples repeated 50 times).
        larger_dataset = data * 50
        for data_type in [larger_dataset, iter(larger_dataset)]:
            predictions = list(
                model._predict_iter(data_type, **predict_kwargs))
            assert 150 == len(predictions)
            predictions = _predictions_only(predictions, larger_dataset)
            assert isinstance(predictions[0], dict)
            # second_last is "Another day at the office" (5 tokens) and last
            # is the 9-token laptop sentence.
            second_last, last = predictions[-2], predictions[-1]
            assert 5 == len(second_last['tags'])
            assert 9 == len(second_last['class_probabilities'])
            assert 9 == len(last['tags'])
            assert 9 == len(last['class_probabilities'])

        # Feeding no data at all, which can happen when an iterator has
        # already been consumed elsewhere, must yield no predictions.
        alt_data = iter(data)
        # Exhaust alt_data so that nothing is left for _predict_iter to read.
        assert 3 == len(list(alt_data))
        predictions = list(model._predict_iter(alt_data, **predict_kwargs))
        assert not predictions