Пример #1
0
    def test_check_pipeline(self, mock_open, mock_load_model,
                            mock_base_processor):
        """Verify ``check_pipeline`` cross-checks component parameters.

        Covers a consistent pipeline (no warnings), model/preprocessor,
        model/postprocessor and preprocessor/postprocessor mismatches
        (``RuntimeWarning``), ``skip_postprocessor=True`` skipping the
        processor comparison, and ``error_on_mismatch=True`` raising
        ``RuntimeError`` instead of warning.
        """
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_base_processor)

        data_labeler = UnstructuredDataLabeler()

        # check pipeline with no errors w overlap
        data_labeler._model = mock.Mock()
        data_labeler._model.get_parameters.return_value = dict(a=1)
        data_labeler._preprocessor.get_parameters.return_value = dict(a=1)
        data_labeler._postprocessor.get_parameters.return_value = dict(a=1)
        with warnings.catch_warnings(record=True) as w:
            # Record every warning regardless of the active filters (e.g. an
            # "ignore" installed via -W / PYTHONWARNINGS), as assertWarns*
            # does; otherwise the zero-warning check could pass vacuously.
            warnings.simplefilter("always")
            data_labeler.check_pipeline()
        self.assertEqual(0, len(w))  # assert no warnings raised

        # invalid pipeline, model != preprocessor
        data_labeler._model.get_parameters.return_value = dict(a=1)
        data_labeler.preprocessor.get_parameters.return_value = dict(a=2)
        with self.assertWarnsRegex(
                RuntimeWarning, 'Model and preprocessor value for `a` do '
                'not match. 1 != 2'):
            data_labeler.check_pipeline()

        # invalid pipeline, model != postprocessor
        data_labeler.preprocessor.get_parameters.return_value = dict(a=1)
        data_labeler.postprocessor.get_parameters.return_value = dict(a=2)
        with self.assertWarnsRegex(
                RuntimeWarning, 'Model and postprocessor value for `a` do '
                'not match. 1 != 2'):
            data_labeler.check_pipeline()

        # invalid pipeline, preprocessor != postprocessor
        data_labeler._model = mock.Mock()
        data_labeler._model.get_parameters.return_value = dict(a=1)
        data_labeler.preprocessor.get_parameters.return_value = dict(a=1, b=1)
        data_labeler.postprocessor.get_parameters.return_value = dict(a=1, b=2)
        with self.assertWarnsRegex(
                RuntimeWarning, 'Preprocessor and postprocessor value for '
                '`b` do not match. 1 != 2'):
            data_labeler.check_pipeline()

        # valid pipeline, preprocessor != postprocessor but skips processor
        data_labeler._model = mock.Mock()
        data_labeler._model.get_parameters.return_value = dict(a=1)
        data_labeler.preprocessor.get_parameters.return_value = dict(a=1, b=1)
        data_labeler.postprocessor.get_parameters.return_value = dict(a=1, b=2)
        with warnings.catch_warnings(record=True) as w:
            # same reasoning as above: make the zero-warning check meaningful
            warnings.simplefilter("always")
            data_labeler.check_pipeline(skip_postprocessor=True)
        self.assertEqual(0, len(w))

        # assert raises error instead of warning (processor check not skipped)
        with self.assertRaisesRegex(
                RuntimeError, 'Preprocessor and postprocessor value for '
                '`b` do not match. 1 != 2'):
            data_labeler.check_pipeline(error_on_mismatch=True)
Пример #2
0
    def test_labels(self, mock_open, mock_load_model, mock_load_processor,
                    *mocks):
        """A default-constructed labeler exposes the expected label list."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()
        expected_labels = ["PAD", "UNKNOWN", "ADDRESS", "PERSON"]

        self.assertListEqual(expected_labels, labeler.labels)
Пример #3
0
    def test_labels(self, mock_open, mock_load_model, mock_load_processor,
                    *mocks):
        """A default-constructed labeler exposes the expected label list."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()
        expected_labels = ['PAD', 'BACKGROUND', 'ADDRESS', 'PERSON']

        self.assertListEqual(expected_labels, labeler.labels)
Пример #4
0
    def test_load_labeler(self, mock_open, mock_load_model,
                          mock_base_processor):
        """Default labeler loads its label mapping, labels and processors."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_base_processor)

        labeler = UnstructuredDataLabeler()

        # label metadata matches data_labeler_parameters
        self.assertDictEqual(labeler.label_mapping,
                             data_labeler_parameters["label_mapping"])
        self.assertListEqual(labeler.labels,
                             ["PAD", "UNKNOWN", "ADDRESS", "PERSON"])

        # each pipeline processor derives from its expected base class
        expected_bases = (
            ("preprocessor", data_processing.BaseDataPreprocessor),
            ("postprocessor", data_processing.BaseDataPostprocessor),
        )
        for attr_name, base_cls in expected_bases:
            self.assertIsInstance(getattr(labeler, attr_name), base_cls)
Пример #5
0
    def test_load_labeler(self, mock_open, mock_load_model,
                          mock_base_processor):
        """Default labeler loads its label mapping, labels and processors."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_base_processor)

        labeler = UnstructuredDataLabeler()

        # label metadata matches data_labeler_parameters
        self.assertDictEqual(labeler.label_mapping,
                             data_labeler_parameters['label_mapping'])
        self.assertListEqual(labeler.labels,
                             ['PAD', 'BACKGROUND', 'ADDRESS', 'PERSON'])

        # each pipeline processor derives from its expected base class
        expected_bases = (
            ('preprocessor', data_processing.BaseDataPreprocessor),
            ('postprocessor', data_processing.BaseDataPostprocessor),
        )
        for attr_name, base_cls in expected_bases:
            self.assertIsInstance(getattr(labeler, attr_name), base_cls)
Пример #6
0
    def test_set_labels(self, mock_open, mock_load_model, mock_load_processor,
                        *mocks):
        """``set_labels`` forwards list and dict labels to the model."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()

        # (labels passed in, label_mapping the model must receive); the
        # expected value is a separate literal, not the object passed in
        cases = (
            (["a", "b", "d", "c"], ["a", "b", "d", "c"]),
            (dict(b=1, c=2, d=3, e=4), dict(b=1, c=2, d=3, e=4)),
        )
        for new_labels, expected_mapping in cases:
            labeler.set_labels(new_labels)
            mock_load_model.return_value.set_label_mapping.assert_called_with(
                label_mapping=expected_mapping)
Пример #7
0
    def test_reverse_label_mappings(self, mock_open, mock_load_model,
                                    mock_load_processor, *mocks):
        """Default labeler maps each label index back to its label name."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()

        # index -> name, in label order
        expected = dict(enumerate(["PAD", "UNKNOWN", "ADDRESS", "PERSON"]))
        self.assertDictEqual(expected, labeler.reverse_label_mapping)
Пример #8
0
    def test_set_preprocessor(self, mock_open, mock_load_model,
                              mock_load_processor, *mocks):
        """``set_preprocessor`` stores processors and rejects other types."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()

        # a CharPreprocessor-spec'd mock is accepted and stored
        new_preprocessor = mock.Mock(spec=data_processing.CharPreprocessor)
        labeler.set_preprocessor(new_preprocessor)
        self.assertEqual(new_preprocessor, labeler.preprocessor)

        # a non-processing object is rejected with TypeError
        type_error_regex = ('The specified preprocessor was not of the '
                            'correct type, `DataProcessing`.')
        with self.assertRaisesRegex(TypeError, type_error_regex):
            labeler.set_preprocessor(1)
Пример #9
0
    def test_set_model(self, mock_open, mock_load_model, mock_load_processor,
                       *mocks):
        """``set_model`` stores model objects and rejects other types."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()

        # a CharacterLevelCnnModel-spec'd mock is accepted and stored
        new_model = mock.Mock(spec=CharacterLevelCnnModel)
        labeler.set_model(new_model)
        self.assertEqual(new_model, labeler.model)

        # a non-model object is rejected with TypeError
        type_error_regex = ('The specified model was not of the correct'
                            ' type, `BaseModel`.')
        with self.assertRaisesRegex(TypeError, type_error_regex):
            labeler.set_model(1)
Пример #10
0
    def test_reverse_label_mappings(self, mock_open, mock_load_model,
                                    mock_load_processor, *mocks):
        """Default labeler maps each label index back to its label name."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        labeler = UnstructuredDataLabeler()

        # index -> name, in label order
        expected = dict(enumerate(['PAD', 'BACKGROUND', 'ADDRESS', 'PERSON']))
        self.assertDictEqual(expected, labeler.reverse_label_mapping)
Пример #11
0
    def test_save_to_disk(self, mock_open, mock_load_model,
                          mock_load_processor, *mocks):
        """``save_to_disk`` writes each pipeline component's class as JSON."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_load_processor)

        # call func
        data_labeler = UnstructuredDataLabeler()

        # setup save mock
        mock_file = setup_save_mock_open(mock_open)
        # Register the close as a cleanup so the buffer is released even if
        # the assertion below fails (a trailing close would be skipped).
        # StringIO.close is called unbound, as before -- presumably because
        # setup_save_mock_open overrides the instance's own close; confirm.
        self.addCleanup(StringIO.close, mock_file)

        # save and test
        data_labeler.save_to_disk('test/path')
        self.assertEqual(
            '{"model": {"class": "CharacterLevelCnnModel"}, '
            '"preprocessor": {"class": "CharPreprocessor"}, '
            '"postprocessor": {"class": "CharPostprocessor"}}',
            mock_file.getvalue())
Пример #12
0
    def test_set_params(self, mock_open, mock_load_model, mock_base_processor):
        """Check ``set_params`` validation, overlap warnings and the error
        raised for parameters addressed to an unset pipeline component."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_base_processor)

        labeler = UnstructuredDataLabeler()

        # malformed params (None, empty, unknown key) are rejected
        bad_format_regex = re.escape(
            "The params dict must have the following "
            "format:\nparams=dict(preprocessor=dict(..."
            "), model=dict(...), postprocessor=dict(..."
            ")), where each sub-dict contains "
            "parameters of the specified data_labeler "
            "pipeline components.")
        for bad_params in (None, {}, {"bad key": None}):
            with self.assertRaisesRegex(ValueError, bad_format_regex):
                labeler.set_params(bad_params)

        # well-formed params are accepted when components report no params
        for component in (labeler._preprocessor, labeler._model,
                          labeler._postprocessor):
            component.get_parameters.return_value = dict()

        labeler.set_params({
            "preprocessor": {"test": 1},
            "model": {"test": 1},
            "postprocessor": {"test2": 3},
        })

        # conflicting values for a shared key produce a RuntimeWarning;
        # here we presume parameters are set as dict(test=1), dict(test=2)
        labeler._preprocessor.get_parameters.return_value = dict(test=1)
        labeler._model.get_parameters.return_value = dict(test=2)
        with self.assertWarnsRegex(
                RuntimeWarning,
                "Model and preprocessor value for `test` do "
                "not match. 2 != 1",
        ):
            labeler.set_params({"preprocessor": {"test": 1},
                                "model": {"test": 2}})

        # params for an unset component are rejected; components are
        # cleared cumulatively, one per step
        missing_component_regex = (
            "Parameters for the preprocessor, model, or"
            " postprocessor were specified when one or "
            "more of these were not set in the "
            "DataLabeler.")
        for attr_name, params_key in (("_preprocessor", "preprocessor"),
                                      ("_model", "model"),
                                      ("_postprocessor", "postprocessor")):
            setattr(labeler, attr_name, None)
            with self.assertRaisesRegex(ValueError, missing_component_regex):
                labeler.set_params({params_key: {"test": 1}})
Пример #13
0
    def test_set_params(self, mock_open, mock_load_model, mock_base_processor):
        """Check ``set_params`` validation, overlap warnings and the error
        raised for parameters addressed to an unset pipeline component."""
        self._setup_mock_load_model(mock_load_model)
        self._setup_mock_load_processor(mock_base_processor)

        labeler = UnstructuredDataLabeler()

        # malformed params (None, empty, unknown key) are rejected
        bad_format_regex = re.escape(
            'The params dict must have the following '
            'format:\nparams=dict(preprocessor=dict(...'
            '), model=dict(...), postprocessor=dict(...'
            ')), where each sub-dict contains '
            'parameters of the specified data_labeler '
            'pipeline components.')
        for bad_params in (None, {}, {'bad key': None}):
            with self.assertRaisesRegex(ValueError, bad_format_regex):
                labeler.set_params(bad_params)

        # well-formed params are accepted when components report no params
        for component in (labeler._preprocessor, labeler._model,
                          labeler._postprocessor):
            component.get_parameters.return_value = dict()

        labeler.set_params({
            'preprocessor': {'test': 1},
            'model': {'test': 1},
            'postprocessor': {'test2': 3},
        })

        # conflicting values for a shared key produce a RuntimeWarning;
        # here we presume parameters are set as dict(test=1), dict(test=2)
        labeler._preprocessor.get_parameters.return_value = dict(test=1)
        labeler._model.get_parameters.return_value = dict(test=2)
        overlap_regex = ('Model and preprocessor value for `test` do '
                         'not match. 2 != 1')
        with self.assertWarnsRegex(RuntimeWarning, overlap_regex):
            labeler.set_params({'preprocessor': {'test': 1},
                                'model': {'test': 2}})

        # params for an unset component are rejected; components are
        # cleared cumulatively, one per step
        missing_component_regex = ('Parameters for the preprocessor, model, or'
                                   ' postprocessor were specified when one or '
                                   'more of these were not set in the '
                                   'DataLabeler.')
        for params_key in ('preprocessor', 'model', 'postprocessor'):
            setattr(labeler, '_' + params_key, None)
            with self.assertRaisesRegex(ValueError, missing_component_regex):
                labeler.set_params({params_key: {'test': 1}})