Ejemplo n.º 1
0
    def test_external_1_and_2(self):
        env_orig_value = os.environ.get(self.ENV_VAR, None)

        os.environ[self.ENV_VAR] = OS_ENV_PATH_SEP.join([self.EXT_MOD_1,
                                                         self.EXT_MOD_2])
        m = self.test_get_internal_modules(True)

        self.assertIn('ImplExternal1', m)
        self.assertIn('ImplExternal2', m)
        self.assertIn('ImplExternal3', m)

        self.assertEqual(m['ImplExternal1']().inst_method('d'), 'external1d')
        self.assertEqual(m['ImplExternal2']().inst_method('e'), 'external2e')
        self.assertEqual(m['ImplExternal3']().inst_method('f'), 'external3f')

        if env_orig_value:
            os.environ[self.ENV_VAR] = env_orig_value
        else:
            del os.environ[self.ENV_VAR]
Ejemplo n.º 2
0
    def test_external_1_and_2_and_garbage(self):
        env_orig_value = os.environ.get(self.ENV_VAR, None)

        os.environ[self.ENV_VAR] = OS_ENV_PATH_SEP.join([self.EXT_MOD_1,
                                                         self.EXT_MOD_2,
                                                         'asdgasfhsadf',
                                                         'some thing weird',
                                                         'but still uses sep'])
        m = self.test_get_internal_modules(True)

        self.assertIn('ImplExternal1', m)
        self.assertIn('ImplExternal2', m)
        self.assertIn('ImplExternal3', m)

        self.assertEqual(m['ImplExternal1']().inst_method('d'), 'external1d')
        self.assertEqual(m['ImplExternal2']().inst_method('e'), 'external2e')
        self.assertEqual(m['ImplExternal3']().inst_method('f'), 'external3f')

        if env_orig_value:
            os.environ[self.ENV_VAR] = env_orig_value
        else:
            del os.environ[self.ENV_VAR]
Ejemplo n.º 3
0
class TestClassifierService(unittest.TestCase):

    # noinspection PyUnresolvedReferences
    @mock.patch.dict(
        os.environ, {
            Pluggable.PLUGIN_ENV_VAR:
            OS_ENV_PATH_SEP.join([
                STUB_CLASSIFIER_MOD_PATH,
                'tests.web.classifier_service.dummy_descriptor_generator',
            ])
        })
    def setUp(self):
        super(TestClassifierService, self).setUp()
        self.config = SmqtkClassifierService.get_default_config()

        self.config['classification_factory']['type'] = \
            'MemoryClassificationElement'
        del self.config['classification_factory']['FileClassificationElement']

        del self.config['classifier_collection']['__example_label__']
        self.dummy_label = 'dummy'
        self.config['classifier_collection'][self.dummy_label] = {
            'DummyClassifier': {},
            'type': 'DummyClassifier'
        }
        self.config['immutable_labels'] = [self.dummy_label]

        self.config['descriptor_factory']['type'] = 'DescriptorMemoryElement'
        del self.config['descriptor_factory']['DescriptorFileElement']

        self.config['descriptor_generator'] = {
            'DummyDescriptorGenerator': {},
            'type': 'DummyDescriptorGenerator'
        }

        self.config['iqr_state_classifier_config']['type'] = \
            'DummySupervisedClassifier'

        self.config['enable_classifier_removal'] = True

        self.config['flask_app'] = {}
        del self.config['server']

        self.app = SmqtkClassifierService(json_config=self.config)

    def assertStatus(self, rv, code):
        self.assertEqual(rv.status_code, code)

    def assertResponseMessageRegex(self, rv, regex):
        self.assertRegexpMatches(
            json.loads(rv.data.decode())['message'], regex)

    def assertMessage(self, resp_data, message):
        self.assertEqual(resp_data['message'], message)

    def test_server_alive(self):
        rv = self.app.test_client().get('/is_ready')

        self.assertStatus(rv, 200)
        resp_data = json.loads(rv.data.decode())
        self.assertMessage(resp_data, "Yes, I'm alive!")

    def test_preconfigured_labels(self):
        rv = self.app.test_client().get('/classifier_labels')

        self.assertStatus(rv, 200)
        resp_data = json.loads(rv.data.decode())
        self.assertMessage(resp_data, "Classifier labels.")
        self.assertListEqual(resp_data['labels'], ['dummy'])

    def test_one_classify(self):
        results_exp = dict(positive=0.5, negative=0.5)
        label = 'dummy'
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()

        with self.app.test_client() as cli:
            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][label], results_exp)

            # Same classifier, just retrieving it different ways, so skip the
            # correctness check
            rv = cli.post('/classify',
                          data={
                              'label': label,
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)

            rv = cli.post('/classify',
                          data={
                              'label': json.dumps(label),
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)

            rv = cli.post('/classify',
                          data={
                              'label': json.dumps([label]),
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)

    def test_adjusted_classify(self):
        results_exp = dict(positive=0.5, negative=0.5)
        label = 'dummy'
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()

        with self.app.test_client() as cli:
            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': element,
                              'adjustment': json.dumps({
                                  'positive': 0,
                              }),
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][label], results_exp)

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': element,
                              'adjustment': json.dumps({
                                  'positive': -1,
                              }),
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            result = resp_data['result'][label]
            self.assertAlmostEqual(result['positive'], 1 / (1 + math.exp(-1)))
            self.assertAlmostEqual(result['negative'], 1 / (1 + math.exp(1)))

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': element,
                              'adjustment': json.dumps({
                                  'positive': 1,
                              }),
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            result = resp_data['result'][label]
            self.assertAlmostEqual(result['positive'], 1 / (1 + math.exp(1)))
            self.assertAlmostEqual(result['negative'], 1 / (1 + math.exp(-1)))

    def test_multiple_classify(self):
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()
        results_exp = dict(positive=0.5, negative=0.5)
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data)
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = 'true'

        with self.app.test_client() as cli:
            rv = cli.post('/classifier',
                          data={
                              'label': new_label,
                              'lock_label': lock_clfr_str,
                              'bytes_b64': enc_data,
                          })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'." % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.post('/classify',
                          data={
                              'label': new_label,
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][new_label], results_exp)

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': element,
                          })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][old_label], results_exp)
            self.assertDictEqual(resp_data['result'][new_label], results_exp)

    def test_get_add_del_classifier(self):
        old_label = 'dummy'
        new_label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.get('/classifier', data={
                'label': old_label,
            })
            self.assertStatus(rv, 200)
            enc_data = rv.data.decode()

            rv = cli.post('/classifier',
                          data={
                              'label': new_label,
                              'bytes_b64': enc_data,
                          })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'." % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

            rv = cli.delete('/classifier', data={
                'label': new_label,
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Removed classifier with label '%s'." % new_label)
            self.assertEqual(resp_data["removed_label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']), {old_label})

    def test_add_imm_del_classifier(self):
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data).decode('utf8')
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = 'true'

        with self.app.test_client() as cli:
            rv = cli.post('/classifier',
                          data={
                              'label': new_label,
                              'lock_label': lock_clfr_str,
                              'bytes_b64': enc_data,
                          })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'." % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

            rv = cli.delete('/classifier', data={
                'label': new_label,
            })
            self.assertStatus(rv, 405)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(
                resp_data["message"],
                "Label '%s' refers to a classifier that is"
                " immutable." % new_label)
            self.assertEqual(resp_data['label'], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

    def test_post_classifier_failures(self):
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data)
        bad_data = base64.b64encode(pickle.dumps(object()))
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = '['

        with self.app.test_client() as cli:
            rv = cli.post('/classifier',
                          data={
                              'label': old_label,
                              'bytes_b64': enc_data,
                          })
            self.assertStatus(rv, 400)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(
                resp_data, "Label '%s' already exists in classifier"
                " collection." % old_label)
            self.assertEqual(resp_data['label'], old_label)

            rv = cli.post('/classifier', data={'label': old_label})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No state base64 data provided.")

            rv = cli.post('/classifier', data={'bytes_b64': enc_data})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No descriptive label provided.")

            rv = cli.post('/classifier',
                          data={
                              'label': old_label,
                              'lock_label': lock_clfr_str,
                              'bytes_b64': enc_data,
                          })
            self.assertStatus(rv, 400)
            self.assertMessage(
                json.loads(rv.data.decode()),
                "Invalid boolean value for 'lock_label'."
                " Was given: '%s'" % lock_clfr_str)

            rv = cli.post('/classifier',
                          data={
                              'label': new_label,
                              'bytes_b64': bad_data,
                          })
            self.assertStatus(rv, 400)
            self.assertMessage(
                json.loads(rv.data.decode()),
                "Data added for label '%s' is not a"
                " Classifier." % new_label)

    def test_del_classifier_failures(self):
        old_label = 'dummy'
        new_label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.delete('/classifier', data={})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No label provided.")

            rv = cli.delete('/classifier', data={'label': old_label})
            self.assertStatus(rv, 405)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(
                resp_data, "Label '%s' refers to a classifier that is"
                " immutable." % old_label)
            self.assertEqual(resp_data['label'], old_label)

            rv = cli.delete('/classifier', data={'label': new_label})
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(
                resp_data, "Label '%s' does not refer to a classifier"
                " currently registered." % new_label)
            self.assertEqual(resp_data['label'], new_label)

    def test_get_classifier_failures(self):
        label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.get('/classifier', data={})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No label provided.")

            rv = cli.get('/classifier', data={'label': label})
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(
                resp_data, "Label '%s' does not refer to a classifier"
                " currently registered." % label)
            self.assertEqual(resp_data['label'], label)

    def test_classify_failures(self):
        content_type = 'text/plain'
        bytes_b64 = base64.b64encode(b'TEST ELEMENT').decode()
        label_invalid_json_failure = '['
        label_valid_json_failure = '{}'
        label_valid_json_list_failure = '["test", {}]'
        missing_clfrs_1 = ['dummy', 'foo']
        missing_clfrs_2 = ['dummy', 'foo', 'bar']
        missing_clfrs_3 = ['foo']
        missing_clfrs_4 = ['foo', 'bar']

        with self.app.test_client() as cli:
            rv = cli.post('/classify', data={
                'content_type': content_type,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No base-64 bytes provided.")

            rv = cli.post('/classify', data={
                'bytes_b64': bytes_b64,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No content type provided.")

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': label_invalid_json_failure,
                          })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Label(s) are not properly formatted JSON.")

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': label_valid_json_failure,
                          })
            self.assertStatus(rv, 400)
            self.assertMessage(
                json.loads(rv.data.decode()),
                "Label must be a list of strings or a single"
                " string.")

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': label_valid_json_list_failure,
                          })
            self.assertStatus(rv, 400)
            self.assertMessage(
                json.loads(rv.data.decode()),
                "Label must be a list of strings or a single"
                " string.")

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': json.dumps(missing_clfrs_1),
                          })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assert_(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_1) - {'dummy'})

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': json.dumps(missing_clfrs_2),
                          })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assert_(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_2) - {'dummy'})

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': json.dumps(missing_clfrs_3),
                          })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assert_(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_3))

            rv = cli.post('/classify',
                          data={
                              'content_type': content_type,
                              'bytes_b64': bytes_b64,
                              'label': json.dumps(missing_clfrs_4),
                          })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assert_(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_4))

    def test_get_classifier_metadata_no_label(self):
        with self.app.test_client() as cli:
            #: :type: flask.wrappers.Response
            r = cli.get('/classifier_metadata')
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 400)
            self.assertMessage(r_json, "No label provided.")

    def test_get_classifier_metadata_invalid_label(self):
        with self.app.test_client() as cli:
            args = dict(label="no-valid-label")
            r = cli.get('/classifier_metadata', query_string=args)
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 404)
            self.assertMessage(
                r_json, "Label 'no-valid-label' does not refer "
                "to a classifier currently registered.")

    def test_get_classifier_labels_mocked(self):
        """
        Test that we can request the registered dummy classifiers class labels.
        Using mock objects to assert calls made.
        """
        expected_label = 'this-test-label'
        expected_class_labels = ['foo', 'bar', 'shazam']

        mock_classifier = mock.Mock(spec=Classifier)
        mock_classifier.get_labels = \
            mock.Mock(return_value=expected_class_labels)

        self.app.classifier_collection.labels = mock.Mock(
            return_value={expected_label})
        self.app.classifier_collection.get_classifier = mock.Mock(
            return_value=mock_classifier)

        with self.app.test_client() as cli:
            args = dict(label=expected_label)
            r = cli.get('/classifier_metadata', query_string=args)

            self.app.classifier_collection.labels.assert_called_once_with()
            self.app.classifier_collection.get_classifier\
                    .assert_called_once_with(expected_label)
            mock_classifier.get_labels.assert_called_once_with()

            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 200)
            self.assertMessage(r_json, "Success")
            self.assertIn('class_labels', r_json)
            self.assertListEqual(r_json['class_labels'], expected_class_labels)

    def test_get_classifier_labels(self):
        """
        Test that we can request the registered dummy classifiers class labels.
        """
        with self.app.test_client() as cli:
            args = dict(label=self.dummy_label)
            r = cli.get('/classifier_metadata', query_string=args)
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 200)
            self.assertMessage(r_json, "Success")
            self.assertIn('class_labels', r_json)
            self.assertSetEqual(set(r_json['class_labels']),
                                {'negative', 'positive'})

    def test_add_iqr_state_classifier_param_failures(self):
        test_bytes = b"some not used bytes"
        test_bytes_b64 = base64.b64encode(test_bytes).decode()
        test_label = "classifier-test-label"

        with self.app.test_client() as cli:
            # Missing Bytes
            rv = cli.post('/iqr_classifier')
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(rv,
                                            "No state base64 data provided.")

            # Missing label for classifier
            rv = cli.post('/iqr_classifier',
                          data={
                              "bytes_b64": test_bytes_b64,
                          })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(rv,
                                            "No descriptive label provided.")

            # Invalid lock flag value (not a boolean)
            rv = cli.post('/iqr_classifier',
                          data={
                              'bytes_b64': test_bytes_b64,
                              'label': test_label,
                              'lock_label': 'not-bool-convertible'
                          })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "Invalid boolean value for 'lock_label'. Was given: ")

    def test_add_iqr_state_classifier_existing_label(self):
        test_label = 'duplicate-label'
        test_new_cfier_b64 = base64.b64encode(b"some not used bytes")

        self.app.classifier_collection.add_classifier(
            test_label, DummySupervisedClassifier())

        with self.app.test_client() as cli:
            rv = cli.post('/iqr_classifier',
                          data={
                              'bytes_b64': test_new_cfier_b64,
                              'label': test_label,
                          })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "Label already exists in classifier collection.")

    @mock.patch.dict(
        os.environ, {
            Pluggable.PLUGIN_ENV_VAR:
            OS_ENV_PATH_SEP.join([
                STUB_CLASSIFIER_MOD_PATH,
                'tests.web.classifier_service.dummy_descriptor_generator',
            ])
        })
    def test_add_iqr_state_classifier_simple(self):
        """
        Test calling IQR classifier add endpoint with a simple IQR Session
        serialization.
        """
        # Make a simple session with dummy adjudication descriptor elements
        iqrs = IqrSession(session_uid=str("0"))
        iqr_p1 = DescriptorMemoryElement('test', 0).set_vector([0])
        iqr_n1 = DescriptorMemoryElement('test', 1).set_vector([1])
        iqrs.adjudicate(new_positives=[iqr_p1], new_negatives=[iqr_n1])

        test_iqrs_b64 = base64.b64encode(iqrs.get_state_bytes())
        test_label = 'test-label-08976azsdv'

        with mock.patch(STUB_CLASSIFIER_MOD_PATH +
                        ".DummySupervisedClassifier._train") as m_cfier_train:

            with self.app.test_client() as cli:
                rv = cli.post('/iqr_classifier',
                              data={
                                  'bytes_b64': test_iqrs_b64,
                                  'label': test_label,
                              })
                self.assertStatus(rv, 201)
                self.assertResponseMessageRegex(
                    rv, "Finished training "
                    "IQR-session-based "
                    "classifier for label "
                    "'%s'." % test_label)

            m_cfier_train.assert_called_once_with({
                'positive': {iqr_p1},
                'negative': {iqr_n1}
            })
            # Collection should include initial dummy classifier and new iqr
            # classifier.
            self.assertEqual(len(self.app.classifier_collection.labels()), 2)
            self.assertIn(test_label, self.app.classifier_collection.labels())
Ejemplo n.º 4
0
class TestDiscoveryViaEnvVar:
    """
    Unit tests for environment variable based discovery of types.
    """

    VAR_NAME = "SOME_SILLY_NAME"

    @mock.patch.dict(os.environ)
    def test_not_set(self) -> None:
        """
        Test when variable of the given name is not set in the environment.
        Nothing should be returned from the discovery.
        """
        # Delete the name in the env if it for some reason really existed...
        # We're in a mock.patch.dict so it's OK to do this with out affecting
        #   parent scope.
        if TestDiscoveryViaEnvVar.VAR_NAME in os.environ:
            del os.environ[TestDiscoveryViaEnvVar.VAR_NAME]
        type_set = discover_via_env_var(TestDiscoveryViaEnvVar.VAR_NAME)
        assert len(type_set) == 0

    @mock.patch.dict(os.environ,
                     {VAR_NAME: "tests.utils.test_plugin_dir.module_of_stuff"})
    def test_module_in_path(self) -> None:
        """
        Test that when a module path exist in the environment variable, it's
        searched for types.
        """
        assert TestDiscoveryViaEnvVar.VAR_NAME in os.environ
        type_set = discover_via_env_var(TestDiscoveryViaEnvVar.VAR_NAME)
        assert type_set == TYPES_IN_STUFF_MODULE

    @mock.patch.dict(
        os.environ, {
            VAR_NAME:
            OS_ENV_PATH_SEP.join([
                "tests.utils.test_plugin_dir.module_of_stuff",
                "tests.utils.test_plugin_dir.module_of_more_stuff",
            ])
        })
    def test_multiple_in_path(self) -> None:
        """
        Test that multiple modules resolve in environment variable.
        """
        test_set = discover_via_env_var(TestDiscoveryViaEnvVar.VAR_NAME)
        assert test_set == (TYPES_IN_STUFF_MODULE | TYPES_IN_MORE_STUFF_MODULE)

    @mock.patch.dict(os.environ, {VAR_NAME: "probably.not.a.valid.path"})
    def test_invalid_in_path(self) -> None:
        """
        Test when there is an invalid module in the path.
        """
        with pytest.raises(ModuleNotFoundError):
            discover_via_env_var(TestDiscoveryViaEnvVar.VAR_NAME)

    @mock.patch.dict(
        os.environ,
        {
            VAR_NAME:
            OS_ENV_PATH_SEP.join([
                "tests.utils.test_plugin_dir.module_of_stuff",
                "tests.utils.test_plugin_dir.module_of_more_stuff",
                # The exception causing module.
                "tests.utils.test_plugin_dir.module_with_exception"
            ])
        })
    def test_module_exception(self) -> None:
        """
        Test when there is a module in the path that raises an exception.
        """
        with pytest.raises(RuntimeError, match=r"^Expected error on import$"):
            discover_via_env_var(TestDiscoveryViaEnvVar.VAR_NAME)
Ejemplo n.º 5
0
class TestClassifierService (unittest.TestCase):

    # noinspection PyUnresolvedReferences
    @mock.patch.dict(os.environ, {
        Pluggable.PLUGIN_ENV_VAR:
            OS_ENV_PATH_SEP.join([
                STUB_CLASSIFIER_MOD_PATH,
                'tests.web.classifier_service.dummy_descriptor_generator',
            ])
    })
    def setUp(self):
        super(TestClassifierService, self).setUp()
        self.config = SmqtkClassifierService.get_default_config()

        key_mce = "smqtk.representation.classification_element.memory.MemoryClassificationElement"
        key_c_dummy = "tests.web.classifier_service.dummy_classifier.DummyClassifier"
        key_dme = "smqtk.representation.descriptor_element.local_elements.DescriptorMemoryElement"
        key_dg_dummy = "tests.web.classifier_service.dummy_descriptor_generator.DummyDescriptorGenerator"
        key_sc_dummy = "tests.web.classifier_service.dummy_classifier.DummySupervisedClassifier"

        self.config['classification_factory']['type'] = key_mce
        # del self.config['classification_factory'][
        #     'smqtk.representation.classification_element.file'
        #     '.FileClassificationElement'
        # ]

        del self.config['classifier_collection']['__example_label__']
        self.dummy_label = 'dummy'
        self.config['classifier_collection'][self.dummy_label] = {
            key_c_dummy: {},
            'type': key_c_dummy
        }
        self.config['immutable_labels'] = [self.dummy_label]

        self.config['descriptor_factory']['type'] = key_dme
        # del self.config['descriptor_factory']['DescriptorFileElement']

        self.config['descriptor_generator'] = {
            key_dg_dummy: {},
            'type': key_dg_dummy
        }

        self.config['iqr_state_classifier_config']['type'] = key_sc_dummy

        self.config['enable_classifier_removal'] = True

        self.config['flask_app'] = {}
        del self.config['server']

        self.app = SmqtkClassifierService(json_config=self.config)

    def assertStatus(self, rv, code):
        self.assertEqual(rv.status_code, code)

    def assertResponseMessageRegex(self, rv, regex):
        self.assertRegex(json.loads(rv.data.decode())['message'], regex)

    def assertMessage(self, resp_data, message):
        self.assertEqual(message, resp_data['message'])

    def test_server_alive(self):
        rv = self.app.test_client().get('/is_ready')

        self.assertStatus(rv, 200)
        resp_data = json.loads(rv.data.decode())
        self.assertMessage(resp_data, "Yes, I'm alive!")

    def test_preconfigured_labels(self):
        rv = self.app.test_client().get('/classifier_labels')

        self.assertStatus(rv, 200)
        resp_data = json.loads(rv.data.decode())
        self.assertMessage(resp_data, "Classifier labels.")
        self.assertListEqual(resp_data['labels'], ['dummy'])

    def test_one_classify(self):
        results_exp = dict(positive=0.5, negative=0.5)
        label = 'dummy'
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()

        with self.app.test_client() as cli:
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][label], results_exp)

            # Same classifier, just retrieving it different ways, so skip the
            # correctness check
            rv = cli.post('/classify', data={
                'label': label,
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)

            rv = cli.post('/classify', data={
                'label': json.dumps(label),
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)

            rv = cli.post('/classify', data={
                'label': json.dumps([label]),
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)

    def test_adjusted_classify(self):
        results_exp = dict(positive=0.5, negative=0.5)
        label = 'dummy'
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()

        with self.app.test_client() as cli:
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': element,
                'adjustment': json.dumps({
                    'positive': 0,
                }),
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][label], results_exp)

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': element,
                'adjustment': json.dumps({
                    'positive': -1,
                }),
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            result = resp_data['result'][label]
            self.assertAlmostEqual(result['positive'], 1/(1+math.exp(-1)))
            self.assertAlmostEqual(result['negative'], 1/(1+math.exp(1)))

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': element,
                'adjustment': json.dumps({
                    'positive': 1,
                }),
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            result = resp_data['result'][label]
            self.assertAlmostEqual(result['positive'], 1/(1+math.exp(1)))
            self.assertAlmostEqual(result['negative'], 1/(1+math.exp(-1)))

    def test_multiple_classify(self):
        content_type = 'text/plain'
        element = base64.b64encode(b'TEST ELEMENT').decode()
        results_exp = dict(positive=0.5, negative=0.5)
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data)
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = 'true'

        with self.app.test_client() as cli:
            rv = cli.post('/classifier', data={
                'label': new_label,
                'lock_label': lock_clfr_str,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'."
                             % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.post('/classify', data={
                'label': new_label,
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][new_label], results_exp)

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': element,
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Finished classification.")
            self.assertDictEqual(resp_data['result'][old_label], results_exp)
            self.assertDictEqual(resp_data['result'][new_label], results_exp)

    def test_get_add_del_classifier(self):
        old_label = 'dummy'
        new_label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.get(f'/classifier?label={old_label}')
            self.assertStatus(rv, 200)
            enc_data = rv.data.decode()

            rv = cli.post('/classifier', data={
                'label': new_label,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'."
                             % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

            rv = cli.delete('/classifier', data={
                'label': new_label,
            })
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Removed classifier with label '%s'."
                             % new_label)
            self.assertEqual(resp_data["removed_label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']), {old_label})

    def test_add_imm_del_classifier(self):
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data).decode('utf8')
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = 'true'

        with self.app.test_client() as cli:
            rv = cli.post('/classifier', data={
                'label': new_label,
                'lock_label': lock_clfr_str,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 201)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Uploaded classifier for label '%s'."
                             % new_label)
            self.assertEqual(resp_data["label"], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

            rv = cli.delete('/classifier', data={
                'label': new_label,
            })
            self.assertStatus(rv, 405)
            resp_data = json.loads(rv.data.decode())
            self.assertEqual(resp_data["message"],
                             "Label '%s' refers to a classifier that is"
                             " immutable." % new_label)
            self.assertEqual(resp_data['label'], new_label)

            rv = cli.get('/classifier_labels')
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data, "Classifier labels.")
            self.assertSetEqual(set(resp_data['labels']),
                                {old_label, new_label})

    def test_post_classifier_failures(self):
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data)
        bad_data = base64.b64encode(pickle.dumps(object()))
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = '['

        with self.app.test_client() as cli:
            rv = cli.post('/classifier', data={
                'label': old_label,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 400)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data,
                               "Label '%s' already exists in classifier"
                               " collection." % old_label)
            self.assertEqual(resp_data['label'], old_label)

            rv = cli.post('/classifier', data={'label': old_label})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No state base64 data provided.")

            rv = cli.post('/classifier', data={'bytes_b64': enc_data})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No descriptive label provided.")

            rv = cli.post('/classifier', data={
                'label': old_label,
                'lock_label': lock_clfr_str,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Invalid boolean value for 'lock_label'."
                               " Was given: '%s'" % lock_clfr_str)

            rv = cli.post('/classifier', data={
                'label': new_label,
                'bytes_b64': bad_data,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Data added for label '%s' is not a"
                               " Classifier." % new_label)

    def test_del_classifier_failures(self):
        old_label = 'dummy'
        new_label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.delete('/classifier', data={})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No label provided.")

            rv = cli.delete('/classifier', data={'label': old_label})
            self.assertStatus(rv, 405)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data,
                               "Label '%s' refers to a classifier that is"
                               " immutable." % old_label)
            self.assertEqual(resp_data['label'], old_label)

            rv = cli.delete('/classifier', data={'label': new_label})
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data,
                               "Label '%s' does not refer to a classifier"
                               " currently registered." % new_label)
            self.assertEqual(resp_data['label'], new_label)

    def test_get_classifier_failures(self):
        label = 'dummy2'

        with self.app.test_client() as cli:
            rv = cli.get('/classifier', data={})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No label provided.")

            rv = cli.get(f'/classifier?label={label}')
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data,
                               "Label '%s' does not refer to a classifier"
                               " currently registered." % label)
            self.assertEqual(resp_data['label'], label)

    def test_classify_uids_no_classifiers(self):
        """ Test that an empty result response comes back when no classifiers
        are registered in the service. """
        # Empty out the classifier collection for the purpose of this test.
        for la in self.app.classifier_collection.labels():
            self.app.classifier_collection.remove_classifier(la)

        with self.app.test_client() as cli:
            rv = cli.post("classify_uids", data=dict(
                uid_list=[0, 1, 2],  # just some meaningless UIDs
            ))
            self.assertStatus(rv, 200)
            resp_data = json.loads(rv.data)
            self.assertMessage(resp_data,
                               "No classifiers currently loaded.")
            assert resp_data['result'] == {}

    def test_classify_uids_no_uid_list(self):
        """ Test the appropriate error is returned when no UID list is
        provided."""
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids")
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json, "No UIDs provided.")

    def test_classify_uids_empty_uid_list(self):
        """ Test the appropriate error is returned when an empty list is
        provided. """
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list='[]'
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json, "No UIDs provided.")

    def test_classify_uids_bad_uid_json(self):
        """ Test that the appropriate error is returned when invalid JSON is
        provided. """
        uid_list_json = 'not json'
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=uid_list_json
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json, "Failed to parse JSON list of UIDs.")

    def test_classify_uids_invalid_label_json(self):
        """ Test providing labels but with a value that is invalid json and
        that the appropriate error occurs. """
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps(['a', 'b', 'c']),
                label="[",
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json,
                               "Invalid label(s) specified: "
                               "Label is not a properly formatted JSON nor a "
                               "simple string.")

    def test_classify_uids_invalid_label_json_value(self):
        """ Test providing labels that is a valid json but not an acceptable
        type and that the appropriate error occurs. """
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps(['a', 'b', 'c']),
                label="{}",
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json,
                               "Invalid label(s) specified: "
                               "Label must be a list of strings or a single "
                               "string (given type: dict).")

    def test_classify_uids_invalid_label_json_inner_value(self):
        """ Test providing a valid json string but with an inner value that is
        not a string and that the appropriate error is returned. """
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps(['a', 'b', 'c']),
                label='["label", []]',
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json,
                               "Invalid label(s) specified: "
                               "Label must be a list of strings or a single "
                               "string: give a list of more than just strings "
                               "(found type: list).")

    def test_classify_uids_missing_labels(self):
        """ Test providing a label for a classifier that is not present in the
        service and that the appropriate error is returned. This check is
        expected to occur before attempting descriptor retrieval or
        classification. """
        missing_clfrs = ['not-present', 'dummy']
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps(['a', 'b', 'c']),
                label=json.dumps(missing_clfrs),
            ))
            self.assertStatus(rv, 404)
            rv_json = rv.json
            # noinspection PyUnresolvedReferences
            assert rv_json['message'].startswith(
                "The following labels are not registered with any "
                "classifiers: "
            )
            # noinspection PyUnresolvedReferences
            assert set(rv_json['missing_labels']) == \
                (set(missing_clfrs) - {"dummy"})

    def test_classify_uids_missing_uids(self):
        """ Test that the appropriate error occurs when requesting descriptor
        UIDs that are not present in the configured descriptor set. """
        # Default setup construction specifies a default, empty descriptor set
        # so any UIDs will be "missing".
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps(['a', 'b', 'c']),
            ))
            self.assertStatus(rv, 400)
            # noinspection PyTypeChecker
            self.assertMessage(rv.json, "One or more input UIDs did not exist "
                                        "in the configured descriptor set!")

    def test_classify_uids(self):
        """ Test a simple invocation. """
        # Add a single descriptor to the default empty DescriptorSet.
        self.app.descriptor_set.add_descriptor(
            DescriptorMemoryElement('test', 0).set_vector([0]))
        with self.app.test_client() as cli:
            #: :type: requests.Response
            rv = cli.post("classify_uids", data=dict(
                uid_list=json.dumps([0]),
            ))
            self.assertStatus(rv, 200)
            rv_json = rv.json
            # DummyClassifier impl is known to just return 50/50
            # classifications.
            # noinspection PyUnresolvedReferences
            assert rv_json['result'] == {
                "dummy": [{"positive": 0.5, "negative": 0.5}]
            }

    def test_classify_failures(self):
        content_type = 'text/plain'
        bytes_b64 = base64.b64encode(b'TEST ELEMENT').decode()
        label_invalid_json_failure = '['
        label_valid_json_failure = '{}'
        label_valid_json_list_failure = '["test", {}]'
        missing_clfrs_1 = ['dummy', 'foo']
        missing_clfrs_2 = ['dummy', 'foo', 'bar']
        missing_clfrs_3 = ['foo']
        missing_clfrs_4 = ['foo', 'bar']

        with self.app.test_client() as cli:
            # When we provide no base-64 data
            rv = cli.post('/classify', data={
                'content_type': content_type,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No base-64 bytes provided.")

            # When we provide no content type for the data provided.
            rv = cli.post('/classify', data={
                'bytes_b64': bytes_b64,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No content type provided.")

            # When we provide invalid JSON for the labels.
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': label_invalid_json_failure,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Invalid label(s) specified: "
                               "Label is not a properly formatted JSON nor a "
                               "simple string.")

            # When we provide valid json, but neither a string nor list
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': label_valid_json_failure,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Invalid label(s) specified: "
                               "Label must be a list of strings or a single "
                               "string (given type: dict).")

            # When one value of a list of labels is not a string.
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': label_valid_json_list_failure,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Invalid label(s) specified: "
                               "Label must be a list of strings or a "
                               "single string: give a list of more "
                               "than just strings (found type: dict).")

            # When providing labels for classifiers not present in the service.
            # 4 variations.
            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': json.dumps(missing_clfrs_1),
            })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertTrue(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_1) - {'dummy'})

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': json.dumps(missing_clfrs_2),
            })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertTrue(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_2) - {'dummy'})

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': json.dumps(missing_clfrs_3),
            })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertTrue(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_3))

            rv = cli.post('/classify', data={
                'content_type': content_type,
                'bytes_b64': bytes_b64,
                'label': json.dumps(missing_clfrs_4),
            })
            self.assertStatus(rv, 404)
            resp_data = json.loads(rv.data.decode())
            self.assertTrue(resp_data['message'].startswith(
                "The following labels are not registered with any"
                " classifiers:"))
            self.assertSetEqual(set(resp_data['missing_labels']),
                                set(missing_clfrs_4))

    def test_get_classifier_metadata_no_label(self):
        with self.app.test_client() as cli:
            #: :type: flask.wrappers.Response
            r = cli.get('/classifier_metadata')
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 400)
            self.assertMessage(r_json, "No label provided.")

    def test_get_classifier_metadata_invalid_label(self):
        with self.app.test_client() as cli:
            args = dict(label="no-valid-label")
            r = cli.get('/classifier_metadata', query_string=args)
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 404)
            self.assertMessage(r_json, "Label 'no-valid-label' does not refer "
                                       "to a classifier currently registered.")

    def test_get_classifier_labels_mocked(self):
        """
        Test that we can request the registered dummy classifiers class labels.
        Using mock objects to assert calls made.
        """
        expected_label = 'this-test-label'
        expected_class_labels = ['foo', 'bar', 'shazam']

        mock_classifier = mock.Mock(spec=Classifier)
        mock_classifier.get_labels = \
            mock.Mock(return_value=expected_class_labels)

        self.app.classifier_collection.labels = mock.Mock(
            return_value={expected_label}
        )
        self.app.classifier_collection.get_classifier = mock.Mock(
            return_value=mock_classifier
        )

        with self.app.test_client() as cli:
            args = dict(label=expected_label)
            r = cli.get('/classifier_metadata', query_string=args)

            self.app.classifier_collection.labels.assert_called_once_with()
            self.app.classifier_collection.get_classifier\
                    .assert_called_once_with(expected_label)
            mock_classifier.get_labels.assert_called_once_with()

            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 200)
            self.assertMessage(r_json, "Success")
            self.assertIn('class_labels', r_json)
            self.assertListEqual(r_json['class_labels'], expected_class_labels)

    def test_get_classifier_labels(self):
        """
        Test that we can request the registered dummy classifiers class labels.
        """
        with self.app.test_client() as cli:
            args = dict(label=self.dummy_label)
            r = cli.get('/classifier_metadata', query_string=args)
            r_json = json.loads(r.data.decode())
            self.assertStatus(r, 200)
            self.assertMessage(r_json, "Success")
            self.assertIn('class_labels', r_json)
            self.assertSetEqual(set(r_json['class_labels']),
                                {'negative', 'positive'})

    def test_add_iqr_state_classifier_param_failures(self):
        test_bytes = b"some not used bytes"
        test_bytes_b64 = base64.b64encode(test_bytes).decode()
        test_label = "classifier-test-label"

        with self.app.test_client() as cli:
            # Missing Bytes
            rv = cli.post('/iqr_classifier')
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "No state base64 data provided."
            )

            # Missing label for classifier
            rv = cli.post('/iqr_classifier', data={
                "bytes_b64": test_bytes_b64,
            })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "No descriptive label provided."
            )

            # Invalid lock flag value (not a boolean)
            rv = cli.post('/iqr_classifier', data={
                'bytes_b64': test_bytes_b64,
                'label': test_label,
                'lock_label': 'not-bool-convertible'
            })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "Invalid boolean value for 'lock_label'. Was given: "
            )

    def test_add_iqr_state_classifier_existing_label(self):
        test_label = 'duplicate-label'
        test_new_cfier_b64 = base64.b64encode(
            b"some not used bytes"
        )

        self.app.classifier_collection.add_classifier(
            test_label, DummySupervisedClassifier()
        )

        with self.app.test_client() as cli:
            rv = cli.post('/iqr_classifier', data={
                'bytes_b64': test_new_cfier_b64,
                'label': test_label,
            })
            self.assertStatus(rv, 400)
            self.assertResponseMessageRegex(
                rv, "Label already exists in classifier collection."
            )

    @mock.patch.dict(os.environ, {
        Pluggable.PLUGIN_ENV_VAR:
            OS_ENV_PATH_SEP.join([
                STUB_CLASSIFIER_MOD_PATH,
                'tests.web.classifier_service.dummy_descriptor_generator',
            ])
    })
    def test_add_iqr_state_classifier_simple(self):
        """
        Test calling IQR classifier add endpoint with a simple IQR Session
        serialization.
        """
        # Make a simple session with dummy adjudication descriptor elements
        rank_relevancy = mock.MagicMock(spec=RankRelevancy)
        iqrs = IqrSession(rank_relevancy, session_uid=str("0"))
        iqr_p1 = DescriptorMemoryElement('test', 0).set_vector([0])
        iqr_n1 = DescriptorMemoryElement('test', 1).set_vector([1])
        iqrs.adjudicate(
            new_positives=[iqr_p1], new_negatives=[iqr_n1]
        )

        test_iqrs_b64 = base64.b64encode(iqrs.get_state_bytes())
        test_label = 'test-label-08976azsdv'

        with mock.patch(STUB_CLASSIFIER_MOD_PATH +
                        ".DummySupervisedClassifier._train") as m_cfier_train:

            with self.app.test_client() as cli:
                rv = cli.post('/iqr_classifier', data={
                    'bytes_b64': test_iqrs_b64,
                    'label': test_label,
                })
                self.assertStatus(rv, 201)
                self.assertResponseMessageRegex(rv, "Finished training "
                                                    "IQR-session-based "
                                                    "classifier for label "
                                                    "'%s'." % test_label)

            m_cfier_train.assert_called_once_with(
                {'positive': {iqr_p1}, 'negative': {iqr_n1}}
            )
            # Collection should include initial dummy classifier and new iqr
            # classifier.
            self.assertEqual(len(self.app.classifier_collection.labels()), 2)
            self.assertIn(test_label, self.app.classifier_collection.labels())
Ejemplo n.º 6
0
    assert 'ImplFoo' in class_dict
    assert 'ImplBar' in class_dict
    assert 'ImplDoExport' in class_dict
    assert 'ImplExternal3' in class_dict
    # Not expected to be picked up from external_2 module
    assert 'ImplExternal1' not in class_dict
    assert 'ImplExternal2' not in class_dict

    # Check that new classes function as expected
    assert class_dict['ImplFoo']().inst_method('a') == 'fooa'
    assert class_dict['ImplBar']().inst_method('b') == 'barb'
    assert class_dict['ImplDoExport']().inst_method('c') == 'doExportc'
    assert class_dict['ImplExternal3']().inst_method('d') == "external3d"


@mock.patch.dict(os.environ, {ENV_VAR: OS_ENV_PATH_SEP.join([EXT_MOD_1,
                                                             EXT_MOD_2])})
def test_external_1_and_2():
    """
    Test loading both external_1 and external_2 module subclasses.
    """
    class_set = get_plugins_for_class(DummyInterface)
    assert len(class_set) == 6

    class_dict = {t.__name__: t for t in class_set}

    # Classes we expect to be discovered
    assert 'ImplFoo' in class_dict
    assert 'ImplBar' in class_dict
    assert 'ImplDoExport' in class_dict
    assert 'ImplExternal1' in class_dict
    assert 'ImplExternal2' in class_dict