class TestDevRemoteBSTClassifier(AsyncTestCase):
    """Test the RemoteBSTClassifier against the dev server"""

    def setUp(self):
        """Create a "bnn" classifier pointed at the dev AI host from the shared config."""
        super().setUp()
        conf = get_conf()
        self.model_params = {
            "hiddenLayers": [5, 6, 7, 8],
            "learningRate": 0.4,
        }
        url = conf['aimetrics']['dev']['hosts']['ai']['base_url']
        self.clf = RemoteBSTClassifier(url, "bnn", model_params=self.model_params)
        self.assertIsNotNone(self.clf)

    @gen_test
    def test_async_fit(self):
        """Fit remotely, then check the fetched model's layer count matches the params."""
        yield self.clf.async_fit(train_X, train_y)
        # test params
        m = yield self.clf.get_model()
        # Expected layer count is the hidden layers plus two more
        # (presumably input and output layers -- confirm against the server).
        self.assertEqual(
            2 + len(self.model_params['hiddenLayers']),
            len(m["model"]["layers"]),
        )

    def tearDown(self):
        """Destroy the remote model if one was created, then run the base teardown.

        The base ``AsyncTestCase.tearDown`` must always run, even if the
        remote cleanup fails. It is called last because the cleanup may
        still need the test's IOLoop.
        """
        try:
            if self.clf.model_id:
                # NOTE(review): if destroy_model() returns a future, it is not
                # awaited here -- confirm it is synchronous.
                self.clf.destroy_model()
        finally:
            super().tearDown()
class TestMockRemoteBSTClassifier(AsyncHTTPTestCase):
    """Test the RemoteBSTClassifier against a mock server"""

    def get_app(self):
        """Build the mock AI application and a "bnn" classifier aimed at it.

        The classifier is created here because ``get_url`` needs the test
        server's port, which is available at the point ``get_app`` is called.
        """
        mock_routes = [
            (r'/classifier/create/(\w+)', MockAICreateHandler),
            (r'/classifier/(\w+)', MockAIObjectHandler),
            (r'/classifier/(\w+)/train', MockAITrainHandler),
            (r'/classifier/(\w+)/predict', MockAIPredictHandler),
        ]
        mock_app = Application(mock_routes)

        self.clf = RemoteBSTClassifier(self.get_url('/'), "bnn")
        self.assertIsNotNone(self.clf)
        return mock_app

    @gen_test
    def test_mock_create_model(self):
        """Model creation should return the id that the mock server hands out."""
        created_id = yield self.clf._create_model()
        self.assertEqual(clf_model_id, created_id)

    @gen_test
    def test_mock_async_fit(self):
        """Async fitting should leave a training error recorded on the classifier."""
        yield self.clf.async_fit(train_X, train_y)
        self.assertIsNotNone(self.clf.training_error)

    def test_mock_fit(self):
        """The synchronous fit path should complete without raising."""
        self.clf.fit(train_X, train_y)