예제 #1
0
    def test_param_validation(self):
        """Check that valid parameters are stored and invalid ones rejected.

        A fully specified, valid parameter dict must be stored unchanged on
        ``model._parameters``. Every entry of ``invalid_parameters`` must make
        the ``RegexModel`` constructor raise ``ValueError``.
        """
        # Make sure all parameters can be altered.
        parameters = {
            'regex_patterns': {
                'PAD': [r'\W'],
                'BACKGROUND': [r'\w']
            },
            'encapsulators': {
                'start': r'(?<![\w.\$\%\-])',
                'end': r'(?:(?=(\b|[ ]))|(?=[^\w\%\$]([^\w]|$))|$)',
            },
            'ignore_case': True,
            'default_label': 'BACKGROUND',
        }
        # Each set uses wrong types for every parameter.
        invalid_parameters = [
            {'regex_patterns': -1, 'encapsulators': "words",
             'default_label': None, 'ignore_case': None},
            {'regex_patterns': 1, 'encapsulators': None,
             'default_label': [], 'ignore_case': 'true'},
            {'regex_patterns': 1, 'encapsulators': tuple(),
             'default_label': -1, 'ignore_case': 2},
            {'regex_patterns': None, 'encapsulators': 3,
             'default_label': 1.2, 'ignore_case': -1},
            {'regex_patterns': 2.2, 'encapsulators': 3,
             'default_label': None, 'ignore_case': {}},
        ]
        model = RegexModel(label_mapping=self.label_mapping,
                           parameters=parameters)
        self.assertDictEqual(parameters, model._parameters)

        # Make sure non-valid params are caught. subTest reports exactly
        # which parameter set failed to raise, and keeps checking the rest.
        for invalid_param_set in invalid_parameters:
            with self.subTest(parameters=invalid_param_set):
                with self.assertRaises(ValueError):
                    RegexModel(label_mapping=self.label_mapping,
                               parameters=invalid_param_set)
예제 #2
0
    def test_labels(self, *mocks):
        """Check the ``labels`` property of a model built from the fixture.

        The model is built with ``self.label_mapping`` and default
        parameters. Its ``labels`` must equal the exact list below.
        """
        expected_labels = ['PAD', 'BACKGROUND', 'ADDRESS']

        # Default parameters: only the label mapping is supplied.
        default_model = RegexModel(self.label_mapping)
        self.assertListEqual(expected_labels, default_model.labels)
예제 #3
0
    def test_save(self, mock_open, *mocks):
        """Save a RegexModel with custom parameters and check what is written.

        The expected output is the model-parameters JSON immediately followed
        by the label-mapping JSON, both captured in the in-memory file that
        ``setup_save_mock_open`` prepares from ``mock_open``.
        """
        # setup mock
        mock_file = setup_save_mock_open(mock_open)
        try:
            # Save a Model with custom parameters
            parameters = {
                'regex_patterns': {
                    'PAD': [r'\W'],
                    'BACKGROUND': [r'\w']
                },
                'encapsulators': {
                    'start': r'(?<![\w.\$\%\-])',
                    'end': r'(?:(?=(\b|[ ]))|(?=[^\w\%\$]([^\w]|$))|$)',
                },
                'ignore_case': True,
                'default_label': 'BACKGROUND',
            }
            label_mapping = {
                'PAD': 0,
                'CITY': 1,  # SAME AS BACKGROUND
                'BACKGROUND': 1,
                'ADDRESS': 2,
            }
            model = RegexModel(label_mapping, parameters)

            # save and test
            model.save_to_disk(".")
            self.assertEqual(
                # model parameters
                '{"regex_patterns": {"PAD": ["\\\\W"], "BACKGROUND": ["\\\\w"]}, '
                '"encapsulators": {"start": "(?<![\\\\w.\\\\$\\\\%\\\\-])", '
                '"end": '
                '"(?:(?=(\\\\b|[ ]))|(?=[^\\\\w\\\\%\\\\$]([^\\\\w]|$))|$)"}, '
                '"ignore_case": true, "default_label": "BACKGROUND"}'
                # label mapping
                '{"PAD": 0, "CITY": 1, "BACKGROUND": 1, "ADDRESS": 2}',
                mock_file.getvalue())
        finally:
            # Close the mock file even when saving or the assertion fails.
            # The close goes through the class rather than mock_file.close():
            # presumably setup_save_mock_open disables the instance's close()
            # so the written content stays readable -- TODO confirm.
            StringIO.close(mock_file)
예제 #4
0
    def test_load(self, *mocks):
        """Load the RegexModel resource directory and compare its parameters.

        The loaded object must be a ``RegexModel``. Its ``encapsulators`` and
        ``regex_patterns`` must match ``mock_model_parameters``.
        """
        # Named model_dir (not ``dir``) so the builtin dir() is not shadowed.
        model_dir = os.path.join(
            _resource_labeler_dir,
            'regex_model/')
        loaded_model = RegexModel.load_from_disk(model_dir)
        self.assertIsInstance(loaded_model, RegexModel)

        self.assertEqual(mock_model_parameters['encapsulators'],
                         loaded_model._parameters['encapsulators'])
        self.assertEqual(mock_model_parameters['regex_patterns'],
                         loaded_model._parameters['regex_patterns'])
예제 #5
0
    def test_reverse_label_mapping(self, *mocks):
        """Check the index-to-label mapping of a model built from the fixture.

        Each index must map back to exactly one label. The original comment
        notes that CITY is absent from the reverse mapping -- presumably
        because the fixture maps CITY to the same index as BACKGROUND
        (verify against the fixture).
        """
        expected_reverse = {
            0: 'PAD',
            1: 'BACKGROUND',
            2: 'ADDRESS',
        }

        # Default parameters: only the label mapping is supplied.
        default_model = RegexModel(self.label_mapping)
        self.assertDictEqual(expected_reverse,
                             default_model.reverse_label_mapping)
예제 #6
0
    def test_set_label_mapping(self, *mocks):
        """Check that ``set_label_mapping`` rejects non-dicts and stores dicts.

        Passing ``None`` must raise ``ValueError`` with the documented
        message. A valid dict, including two labels that share index 1,
        must afterwards be returned unchanged by ``label_mapping``.
        """
        model = RegexModel(self.label_mapping)

        # A non-dict value is rejected.
        with self.assertRaisesRegex(
                ValueError, "`label_mapping` must be a dict which maps labels "
                            "to index encodings."):
            model.set_label_mapping(None)

        # A dict is accepted as-is; CITY deliberately shares BACKGROUND's index.
        new_mapping = dict(PAD=0, CITY=1, BACKGROUND=1, ADDRESS=2)
        model.set_label_mapping(new_mapping)
        self.assertDictEqual(new_mapping, model.label_mapping)
예제 #7
0
    def test_label_mapping(self, *mocks):
        """Check that ``label_mapping`` returns the mapping given to the model."""
        fixture_mapping = self.label_mapping
        default_model = RegexModel(fixture_mapping)
        self.assertDictEqual(fixture_mapping, default_model.label_mapping)
예제 #8
0
    def test_predict(self, mock_stdout):
        """Check RegexModel.predict output, confidences and verbose printing.

        Uses PAD=r'\\W' and BACKGROUND=r'\\w' patterns. Each character is
        predicted as a one-hot row. The three columns presumably follow the
        fixture's PAD/BACKGROUND/ADDRESS indices -- verify against
        ``self.label_mapping``.
        """
        parameters = {
            'regex_patterns': {
                'PAD': [r'\W'],
                'BACKGROUND': [r'\w']
            },
            'ignore_case': True,
            'default_label': 'BACKGROUND',
        }
        model = RegexModel(label_mapping=self.label_mapping,
                           parameters=parameters)

        def assert_arrays_equal(expected_list, output_list):
            # zip() alone would pass vacuously if the model returned fewer
            # samples than expected, so compare the counts first.
            self.assertEqual(len(expected_list), len(output_list))
            for expected, output in zip(expected_list, output_list):
                self.assertTrue(np.array_equal(expected, output))

        # test only pad and background separate
        expected_output = {
            'pred': [np.array([[1, 0, 0],
                               [1, 0, 0],
                               [1, 0, 0]]),
                     np.array([[0, 1, 0],
                               [0, 1, 0],
                               [0, 1, 0],
                               [0, 1, 0],
                               [0, 1, 0]])
                     ]}
        model_output = model.predict(['   ', 'hello'])
        self.assertIn('pred', model_output)
        assert_arrays_equal(expected_output['pred'], model_output['pred'])

        # check verbose printing
        self.assertIn('Data Samples', mock_stdout.getvalue())

        # test pad with background
        expected_output = {
            'pred': [np.array([[1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0]])
                     ]}
        model_output = model.predict([' h w.'])
        self.assertIn('pred', model_output)
        assert_arrays_equal(expected_output['pred'], model_output['pred'])

        # test show confidences
        expected_output = {
            'pred': [np.array([[1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0]])
                     ],
            'conf': [np.array([[1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0],
                               [0, 1, 0],
                               [1, 0, 0]])
                     ]
        }
        model_output = model.predict([' h w.'], show_confidences=True)
        self.assertIn('pred', model_output)
        self.assertIn('conf', model_output)
        assert_arrays_equal(expected_output['pred'], model_output['pred'])
        assert_arrays_equal(expected_output['conf'], model_output['conf'])

        # test verbose = False: clear captured stdout, predict, expect silence
        mock_stdout.seek(0)
        mock_stdout.truncate(0)
        model.predict(['hello world.'], verbose=False)
        self.assertNotIn('Data Samples', mock_stdout.getvalue())
예제 #9
0
 def test_help(self, mock_stdout):
     """Check that RegexModel.help() prints its name and a Parameters section."""
     RegexModel.help()
     printed = mock_stdout.getvalue()
     for expected_fragment in ("RegexModel", "Parameters"):
         self.assertIn(expected_fragment, printed)