def test_missing_labels_file(self, fs: FakeFilesystem) -> None:
    """Check that building a Classifier fails when label.vocab is missing.

    The real test instance directory is added to the fake filesystem and its
    labels file is then removed before the Classifier is constructed.
    """
    # Imported locally so this test does not rely on a module-level import.
    import re

    labels_path = './testdata/test_model/test_instance/label.vocab'
    fs.add_real_directory('./testdata/test_model/test_instance')
    fs.remove_object(labels_path)
    # `match` is applied as a regular expression search, so the path is
    # escaped to make its '.' characters match literally rather than matching
    # any character.
    expected_message = (
        r'Failure to load labels file from {0} with exception').format(
            re.escape(labels_path))
    with pytest.raises(Exception, match=expected_message):
        Classifier(self.BASE_CLASSIFIER_PATH, 'test_model')
def test_invalid_bert(self, fs: FakeFilesystem) -> None:
    """Check that classifying with a bogus "bert" path in the config raises.

    The instance's config.json is replaced in the fake filesystem with one
    whose "bert" entry is `bad_bert_path`; constructing the Classifier and
    classifying a string is then expected to raise.
    """
    bad_bert_path = './bad/path/to/bert'
    # Serialise with json.dumps instead of interpolating into a hand-written
    # JSON literal, so the path value is always escaped into valid JSON.
    config = json.dumps({
        'bert': bad_bert_path,
        'labels': 'label.vocab',
        'is_released': True,
        'description': 'This is the latest model from Sascha.',
        'metadata': {
            'thesaurus': 'issues',
        },
    })
    config_path = './testdata/test_model/test_instance/config.json'
    fs.add_real_directory('./testdata/test_model/test_instance')
    # Swap the real config for the one pointing at the bad bert path.
    fs.remove_object(config_path)
    fs.create_file(config_path, contents=config)
    with pytest.raises(Exception, match='SavedModel file does not exist at'):
        c = Classifier(self.BASE_CLASSIFIER_PATH, 'test_model')
        # Bad bert is only used on uncached embed.
        c.classify(['some string'])
def test_e2e(app: Flask, fs: FakeFilesystem) -> None:
    """Exercise classify, sample upload and both refresh tasks over HTTP.

    quality.json and thresholds.json are removed from the test instance
    before any request is made, so the first classify call is expected to
    return no topics.
    """
    sentence_template = (
        'react more swiftly to comply with international instruments %d')
    instance_dir = './testdata/test_model/test_instance'
    fs.add_real_directory(instance_dir)
    for removed_file in (
            './testdata/test_model/test_instance/quality.json',
            './testdata/test_model/test_instance/thresholds.json'):
        fs.remove_object(removed_file)
    client = app.test_client()

    def send_json(verb, url, payload):
        # Issue a request whose body is the JSON encoding of `payload`.
        return verb(url,
                    data=json.dumps(payload),
                    content_type='application/json')

    def classify(index):
        # Classify the single sequence produced from the template and index.
        return send_json(client.post, '/classify?model=test_model',
                         {'seqs': [sentence_template % index]})

    def run_task(provider, name):
        # Submit the task, require acceptance, then hand off to wait_for_task.
        submitted = send_json(client.post, '/task', {
            'provider': provider,
            'name': name,
            'model': 'test_model'
        })
        assert submitted.status == '200 OK'
        wait_for_task(client, name)

    def fetch_samples():
        # Fetch all stored classification samples for the test model.
        listing = client.get('/classification_sample?model=test_model&seq=*')
        assert listing.status == '200 OK'
        return json.loads(listing.data)

    with app.test_request_context():
        # initial classify returns no topics because we deleted thresholds.json.
        first = classify(1)
        assert first.status == '200 OK'
        assert len(json.loads(first.data)[0]) == 0

        # now we add training labels
        training_samples = []
        for i in range(20):
            training_samples.append({
                'seq': sentence_template % i,
                'training_labels': [{
                    'topic': 'International instruments'
                }]
            })
        upload = send_json(client.put,
                           '/classification_sample?model=test_model',
                           {'samples': training_samples})
        assert upload.status == '200 OK'

        assert len(fetch_samples()) == 20

        run_task('RefreshThresholds', 'thres')

        # With thresholds refreshed, each sequence should get exactly one topic.
        for i in range(20):
            result = classify(i)
            assert result.status == '200 OK'
            assert len(json.loads(result.data)[0]) == 1

        run_task('RefreshPredictions', 'pred')

        samples = fetch_samples()
        assert len(samples) == 20
        top_label = samples[0]['predicted_labels'][0]
        assert top_label['topic'] == 'International instruments'
        assert top_label['quality'] >= 0.5