def test_project_creation_using_locked_indices(self):
    """
    There was a bug case that warrants this test. If any of the indices in
    Elasticsearch are locked for whatever reason, project creation fails because
    it contained lazy index creation, which ignored existing indices and thus
    triggered FORBIDDEN_ACCESS errors in Elasticsearch.
    """
    index_names = ["locked_index_{}".format(i) for i in range(1, 5)]
    for index in index_names:
        self.__create_locked_index(index)

    response = self.client.post(
        reverse(f"{VERSION_NAMESPACE}:project-list"),
        format="json",
        data={
            "title": "faulty_project",
            "indices_write": index_names,
            "users_write": [self.admin.username]
        }
    )

    ec = ElasticCore()
    for index in index_names:
        ec.delete_index(index)

    self.assertTrue(response.status_code == status.HTTP_201_CREATED)
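# The helper __create_locked_index is called above but not shown in this snippet.
# A minimal sketch of what it might look like, assuming ElasticCore().es exposes the
# standard elasticsearch-py client and using the index.blocks.read_only_allow_delete
# setting to lock the index; this is an illustrative assumption, not the project's
# actual implementation.
def __create_locked_index(self, index_name: str):
    # Hypothetical helper: create the index and then mark it read-only,
    # mimicking the locked state Elasticsearch enters e.g. on low disk space.
    ec = ElasticCore()
    ec.es.indices.create(index=index_name, ignore=[400, 404])
    ec.es.indices.put_settings(
        index=index_name,
        body={"index.blocks.read_only_allow_delete": True}
    )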
def tearDown(self) -> None:
    Tagger.objects.all().delete()
    ec = ElasticCore()
    res = ec.delete_index(self.test_index_copy)
    ec.delete_index(index=self.test_index_name, ignore=[400, 404])
    print_output(f"Delete apply_taggers test index {self.test_index_copy}", res)
def tearDown(self) -> None:
    ec = ElasticCore()
    ec.delete_index(index=self.test_index_name, ignore=[400, 404])
    print_output(f"Delete [Rakun Extractor] test index {self.test_index_name}", None)
    Embedding.objects.all().delete()
    ElasticCore().delete_index(index=self.test_index_name, ignore=[400, 404])
    print_output("Delete Rakun FASTTEXT Embeddings", None)
def destroy(self, request, pk=None, **kwargs):
    with transaction.atomic():
        index_name = Index.objects.get(pk=pk).name
        es = ElasticCore()
        es.delete_index(index_name)
        Index.objects.filter(pk=pk).delete()
        return Response({"message": f"Deleted index {index_name} from Elasticsearch!"})
def bulk_delete(self, request, project_pk=None):
    serializer: IndexBulkDeleteSerializer = self.get_serializer(data=request.data)
    serializer.is_valid(raise_exception=True)

    # Initialize Elastic requirements.
    ec = ElasticCore()

    # Get the index names.
    ids = serializer.validated_data["ids"]
    objects = Index.objects.filter(pk__in=ids)
    index_names = [item.name for item in objects]

    # Ensure deletion on both Elastic and DB.
    if index_names:
        ec.delete_index(",".join(index_names))
    deleted = objects.delete()

    info = {"num_deleted": deleted[0], "deleted_types": deleted[1]}
    return Response(info, status=status.HTTP_200_OK)
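# Note on ec.delete_index(",".join(index_names)): Elasticsearch accepts a comma-separated
# list of index names in a single delete request, so the bulk endpoint can drop all of
# them in one call. A rough, hedged equivalent against the raw elasticsearch-py 7.x
# client (the local URL and the ignore flags are assumptions for illustration only):
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")  # assumed local test cluster
# "index_a,index_b" is expanded by Elasticsearch into both named indices.
es.indices.delete(index="index_a,index_b", ignore=[400, 404])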
class IndexViewsTest(APITestCase):

    @classmethod
    def setUpTestData(cls):
        # Owner of the project
        cls.user = create_test_user('user', '*****@*****.**', 'pw', superuser=True)

    def setUp(self) -> None:
        self.client.login(username="******", password="******")
        self.ec = ElasticCore()
        self.ids = []
        self.index_names = ["test_for_index_endpoint_1", "test_for_index_endpoint_2"]
        for index_name in self.index_names:
            index, is_created = Index.objects.get_or_create(name=index_name)
            self.ec.es.indices.create(index=index_name, ignore=[400, 404])
            self.ids.append(index.pk)

    def test_bulk_delete(self):
        url = reverse("v2:index-bulk-delete")
        response = self.client.post(url, data={"ids": self.ids}, format="json")
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        for index_name in self.index_names:
            self.assertFalse(self.ec.es.indices.exists(index_name))
        print_output("test_bulk_delete:response.data", response.data)

    def tearDown(self) -> None:
        indices = Index.objects.filter(pk__in=self.ids)
        names = [index.name for index in indices]
        if names:
            indices.delete()
            for index in names:
                self.ec.delete_index(index=index, ignore=[400, 404])
class BertTaggerObjectViewTests(APITransactionTestCase):
    def setUp(self):
        # Owner of the project
        self.test_index_name = reindex_test_dataset()
        self.user = create_test_user('BertTaggerOwner', '*****@*****.**', 'pw')
        self.admin_user = create_test_user("AdminBertUser", '*****@*****.**', 'pw', superuser=True)
        self.project = project_creation("BertTaggerTestProject", self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.project.users.add(self.admin_user)
        self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/bert_taggers/'
        self.project_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}'
        self.test_tagger_id = None
        self.test_multiclass_tagger_id = None
        self.client.login(username='******', password='******')

        # Check if TEST_BERT_MODEL is already downloaded
        available_models = get_downloaded_bert_models(BERT_PRETRAINED_MODEL_DIRECTORY)
        self.test_model_existed = True if TEST_BERT_MODEL in available_models else False
        download_bert_requirements(BERT_PRETRAINED_MODEL_DIRECTORY, [TEST_BERT_MODEL], cache_directory=BERT_CACHE_DIR, num_labels=2)

        # New fact name and value used when applying the tagger to an index
        self.new_fact_name = "TEST_BERT_TAGGER_NAME"
        self.new_multiclass_fact_name = "TEST_BERT_TAGGER_NAME_MC"
        self.new_fact_value = "TEST_BERT_TAGGER_VALUE"

        # Create a copy of the test index
        self.reindex_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
        # Generate a name for the new index containing a random id to make sure it doesn't already exist
        self.test_index_copy = f"test_apply_bert_tagger_{uuid.uuid4().hex}"

        self.reindex_payload = {
            "description": "test index for applying taggers",
            "indices": [self.test_index_name],
            "query": json.dumps(TEST_QUERY),
            "new_index": self.test_index_copy,
            "fields": TEST_FIELD_CHOICE
        }
        resp = self.client.post(self.reindex_url, self.reindex_payload, format='json')
        print_output("reindex test index for applying bert tagger:response.data:", resp.json())
        self.reindexer_object = Reindexer.objects.get(pk=resp.json()["id"])

        self.test_imported_binary_gpu_tagger_id = self.import_test_model(TEST_BERT_TAGGER_BINARY_GPU)
        self.test_imported_multiclass_gpu_tagger_id = self.import_test_model(TEST_BERT_TAGGER_MULTICLASS_GPU)
        self.test_imported_binary_cpu_tagger_id = self.import_test_model(TEST_BERT_TAGGER_BINARY_CPU)

        self.ec = ElasticCore()

    def import_test_model(self, file_path: str):
        """Import fine-tuned models for testing."""
        print_output("Importing model from file:", file_path)
        import_url = f'{self.url}import_model/'
        with open(file_path, "rb") as fh:
            resp = self.client.post(import_url, data={'file': fh}).json()
        print_output("Importing test model:", resp)
        return resp["id"]

    def test(self):
        self.run_train_multiclass_bert_tagger_using_fact_name()
        self.run_train_balanced_multiclass_bert_tagger_using_fact_name()
        self.run_train_bert_tagger_from_checkpoint_model_bin2bin()
        self.run_train_bert_tagger_from_checkpoint_model_bin2mc()
        self.run_train_binary_multiclass_bert_tagger_using_fact_name()
        self.run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_payload()
        self.run_train_bert_tagger_using_query()
        self.run_bert_tag_text()
        self.run_bert_tag_with_imported_gpu_model()
        self.run_bert_tag_with_imported_cpu_model()
        self.run_bert_tag_random_doc()
        self.run_bert_epoch_reports_get()
        self.run_bert_epoch_reports_post()
        self.run_bert_get_available_models()
        self.run_bert_download_pretrained_model()
        self.run_bert_tag_and_feedback_and_retrain()
        self.run_bert_model_export_import()
        self.run_apply_binary_tagger_to_index()
        self.run_apply_multiclass_tagger_to_index()
        self.run_apply_tagger_to_index_invalid_input()
        self.run_bert_tag_text_persistent()
        self.run_test_that_user_cant_delete_pretrained_model()
        self.run_test_that_admin_users_can_delete_pretrained_model()

        self.add_cleanup_files(self.test_tagger_id)
        self.add_cleanup_folders()

    def tearDown(self) -> None:
        res = self.ec.delete_index(self.test_index_copy)
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
        print_output(f"Delete apply_bert_taggers test index {self.test_index_copy}", res)

    def add_cleanup_files(self, tagger_id: int):
        tagger_object = BertTaggerObject.objects.get(pk=tagger_id)
        self.addCleanup(remove_file, tagger_object.model.path)
        if not TEST_KEEP_PLOT_FILES:
            self.addCleanup(remove_file, tagger_object.plot.path)

    def add_cleanup_folders(self):
        if not self.test_model_existed:
            test_model_dir = os.path.join(BERT_PRETRAINED_MODEL_DIRECTORY, BertTagger.normalize_name(TEST_BERT_MODEL))
            self.addCleanup(remove_folder, test_model_dir)

    def run_train_multiclass_bert_tagger_using_fact_name(self):
        """Tests BertTagger training with multiple classes and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBertTaggerObjectTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "indices": [{"name": self.test_index_name}],
            "maximum_sample_size": 500,
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, payload, format='json')
        print_output('test_create_multiclass_bert_tagger_training_and_task_signal:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        self.test_multiclass_tagger_id = tagger_id
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_multiclass_bert_tagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_multiclass_bert_tagger_has_classes:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) > 2)
        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_bert_tagger_using_fact_name(self):
        """Tests BertTagger training with binary facts."""
        payload = {
            "description": "Test Bert Tagger training binary multiclass",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "pos_label": TEST_POS_LABEL,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{"name": self.test_index_name}],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, payload, format='json')
        print_output('test_run_train_binary_multiclass_bert_tagger_using_fact_name:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_binary_multiclass_bert_tagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_binary_multiclass_bert_tagger_has_classes:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)
        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_payload(self):
        """Tests BertTagger training with binary facts and invalid payloads."""
        # Pos label is undefined by the user
        invalid_payload_1 = {
            "description": "Test Bert Tagger training binary multiclass invalid",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{"name": self.test_index_name}],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, invalid_payload_1, format='json')
        print_output('test_run_train_binary_multiclass_bert_tagger_using_fact_name_missing_pos_label:response.data', response.data)
        # Check if creating the BertTagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        # The pos label the user has inserted is not present in the data
        invalid_payload_2 = {
            "description": "Test Bert Tagger training binary multiclass invalid",
            "fact_name": TEST_FACT_NAME,
            "pos_label": "invalid_fact_val",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{"name": self.test_index_name}],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, invalid_payload_2, format='json')
        print_output('test_run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_pos_label:response.data', response.data)
        # Check if creating the BertTagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_train_balanced_multiclass_bert_tagger_using_fact_name(self):
        """Tests balanced BertTagger training with multiple classes and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBalancedBertTaggerObjectTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "indices": [{"name": self.test_index_name}],
            "maximum_sample_size": 500,
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL,
            "balance": True,
            "use_sentence_shuffle": True,
            "balance_to_max_limit": True
        }
        response = self.client.post(self.url, payload, format='json')
        print_output('test_create_balanced_multiclass_bert_tagger_training_and_task_signal:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        self.test_multiclass_tagger_id = tagger_id
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_balanced_multiclass_bert_tagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_balanced_multiclass_bert_tagger_has_classes:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)

        num_examples = json.loads(response.data["num_examples"])
        print_output('test_balanced_bert_tagger_num_examples_correct:num_examples', num_examples)
        for class_size in num_examples.values():
            self.assertEqual(class_size, payload["maximum_sample_size"])
        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_using_query(self):
        """Tests BertTagger training, and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBertTaggerTraining",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_QUERY),
            "maximum_sample_size": 500,
            "indices": [{"name": self.test_index_name}],
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        print_output(f"training tagger with payload: {payload}", 200)
        response = self.client.post(self.url, payload, format='json')
        print_output('test_create_binary_bert_tagger_training_and_task_signal:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_binary_bert_tagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_binary_bert_tagger_has_classes:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)
        # Set the trained tagger as the active tagger
        self.test_tagger_id = tagger_id
        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_from_checkpoint_model_bin2bin(self):
        """Tests training a binary BertTagger from a binary checkpoint."""
        payload = {
            "description": "Test training binary BertTagger from checkpoint",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_QUERY),
            "maximum_sample_size": 500,
            "indices": [{"name": self.test_index_name}],
            "num_epochs": 2,
            "max_length": 12,
            "bert_model": TEST_BERT_MODEL,
            "checkpoint_model": self.test_tagger_id
        }
        print_output("training binary bert tagger from checkpoint with payload: ", payload)
        response = self.client.post(self.url, payload, format='json')
        print_output('test_train_bert_tagger_from_checkpoint_model_bin2bin:POST:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_train_bert_tagger_from_checkpoint_model_bin2bin.has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_train_bert_tagger_from_checkpoint_model.has_stats:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)
        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_from_checkpoint_model_bin2mc(self):
        """Tests training a multiclass BertTagger from a binary checkpoint."""
        payload = {
            "description": "Test training multiclass BertTagger from checkpoint",
            "fields": TEST_FIELD_CHOICE,
            "fact_name": TEST_FACT_NAME,
            "maximum_sample_size": 500,
            "indices": [{"name": self.test_index_name}],
            "num_epochs": 2,
            "max_length": 12,
            "bert_model": TEST_BERT_MODEL,
            "checkpoint_model": self.test_tagger_id
        }
        print_output("training multiclass bert tagger from checkpoint with payload: ", payload)
        response = self.client.post(self.url, payload, format='json')
        print_output('test_train_bert_tagger_from_checkpoint_model_bin2mc:POST:response.data', response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_train_bert_tagger_from_checkpoint_model_bin2mc.has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))
        print_output('test_train_bert_tagger_from_checkpoint_model_bin2mc.has_classes:response.data.classes', response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)
        self.add_cleanup_files(tagger_id)

    def run_bert_tag_with_imported_gpu_model(self):
        """Tests applying an imported model trained on GPU."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(f'{self.url}{self.test_imported_binary_gpu_tagger_id}/tag_text/', payload)
        print_output('test_bert_tagger_tag_with_imported_gpu_model:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)
        # Check if the tagger learned to predict
        self.assertEqual("true", response.data["result"])
        self.add_cleanup_files(self.test_imported_binary_gpu_tagger_id)

    def run_bert_tag_with_imported_cpu_model(self):
        """Tests applying an imported model trained on CPU."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(f'{self.url}{self.test_imported_binary_cpu_tagger_id}/tag_text/', payload)
        print_output('test_bert_tagger_tag_with_imported_cpu_model:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)
        # Check if the tagger learned to predict
        self.assertEqual("true", response.data["result"])
        self.add_cleanup_files(self.test_imported_binary_cpu_tagger_id)

    def run_bert_tag_text(self):
        """Tests tag prediction for texts."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        print_output('test_bert_tagger_tag_text:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)

    def run_bert_tag_random_doc(self):
        """Tests the endpoint for the tag_random_doc action."""
        # Tag with specified fields
        payload = {
            "indices": [{"name": self.test_index_name}],
            "fields": TEST_FIELD_CHOICE
        }
        url = f'{self.url}{self.test_tagger_id}/tag_random_doc/'
        response = self.client.post(url, format="json", data=payload)
        print_output('test_bert_tagger_tag_random_doc:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the response is a dict
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue("prediction" in response.data)
        self.assertTrue("document" in response.data)
        self.assertTrue("probability" in response.data["prediction"])
        self.assertTrue("result" in response.data["prediction"])
        self.assertTrue("tagger_id" in response.data["prediction"])

        # Tag with unspecified fields
        payload = {"indices": [{"name": self.test_index_name}]}
        url = f'{self.url}{self.test_tagger_id}/tag_random_doc/'
        response = self.client.post(url, format="json", data=payload)
        print_output('test_bert_tagger_tag_random_doc:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the response is a dict
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue("prediction" in response.data)
        self.assertTrue("document" in response.data)
        self.assertTrue("probability" in response.data["prediction"])
        self.assertTrue("result" in response.data["prediction"])
        self.assertTrue("tagger_id" in response.data["prediction"])

    def run_bert_epoch_reports_get(self):
        """Tests the endpoint for retrieving epoch reports via GET."""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        response = self.client.get(url, format="json")
        print_output('test_bert_tagger_epoch_reports_get:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if the first report is not empty
        self.assertTrue(len(response.data[0]) > 0)

    def run_bert_epoch_reports_post(self):
        """Tests the endpoint for retrieving epoch reports via POST."""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        payload_1 = {}
        payload_2 = {"ignore_fields": ["true_positive_rate", "false_positive_rate", "recall"]}

        response = self.client.post(url, format="json", data=payload_1)
        print_output('test_bert_tagger_epoch_reports_post_ignore_default:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if the first report contains recall
        self.assertTrue("recall" in response.data[0])

        response = self.client.post(url, format="json", data=payload_2)
        print_output('test_bert_tagger_epoch_reports_post_ignore_custom:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if the first report does NOT contain recall
        self.assertTrue("recall" not in response.data[0])

    def run_bert_get_available_models(self):
        """Tests the endpoint for retrieving available BERT models."""
        url = f'{self.url}available_models/'
        response = self.client.get(url, format="json")
        print_output('test_bert_tagger_get_available_models:response.data', response.data)
        available_models = get_downloaded_bert_models(BERT_PRETRAINED_MODEL_DIRECTORY)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the endpoint returns the currently available models
        self.assertCountEqual(response.data, available_models)

    def run_bert_download_pretrained_model(self):
        """Tests the endpoint for downloading a pretrained BERT model."""
        self.client.login(username="******", password='******')
        url = f'{self.url}download_pretrained_model/'
        # Test the endpoint with a valid payload
        valid_payload = {"bert_model": "prajjwal1/bert-tiny"}
        response = self.client.post(url, format="json", data=valid_payload)
        print_output('test_bert_tagger_download_pretrained_model_valid_input:response.data', response.data)
        if ALLOW_BERT_MODEL_DOWNLOADS:
            self.assertEqual(response.status_code, status.HTTP_200_OK)
        else:
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        # Test the endpoint with an invalid payload
        invalid_payload = {"bert_model": "foo"}
        response = self.client.post(url, format="json", data=invalid_payload)
        print_output('test_bert_tagger_download_pretrained_model_invalid_input:response.data', response.data)
        # The endpoint should throw an error
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_bert_model_export_import(self):
        """Tests the endpoints for model export and import."""
        # test_tagger_id = self.test_tagger_id
        # Retrieve the model zip
        url = f'{self.url}{self.test_tagger_id}/export_model/'
        response = self.client.get(url)

        # Post the model zip
        import_url = f'{self.url}import_model/'
        response = self.client.post(import_url, data={'file': BytesIO(response.content)})
        print_output('test_bert_import_model:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        bert_tagger = BertTaggerObject.objects.get(pk=response.data["id"])
        tagger_id = response.data['id']

        # Check if the model and plot files exist.
        resources = bert_tagger.get_resource_paths()
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertTrue(file.exists())

        # Test the tag_random_doc action with the imported model
        url = f'{self.url}{bert_tagger.pk}/tag_random_doc/'
        random_doc_payload = {
            "indices": [{"name": self.test_index_name}],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url, data=random_doc_payload, format="json")
        print_output('test_bert_tag_random_doc_after_model_import:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue('prediction' in response.data)
        # Remove the exported tagger files
        self.add_cleanup_files(tagger_id)

    def run_apply_binary_tagger_to_index(self):
        """Tests applying a binary BERT tagger to an index using the apply_to_index endpoint."""
        # Make sure the reindexer task has finished
        while self.reindexer_object.task.status != Task.STATUS_COMPLETED:
            print_output('test_apply_binary_bert_tagger_to_index: waiting for reindexer task to finish, current status:', self.reindexer_object.task.status)
            sleep(2)
            # Refresh the cached task instance so the loop can observe status changes.
            self.reindexer_object.task.refresh_from_db()

        url = f'{self.url}{self.test_imported_binary_gpu_tagger_id}/apply_to_index/'
        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "indices": [{"name": self.test_index_copy}],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url, payload, format='json')
        print_output('test_apply_binary_bert_tagger_to_index:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        tagger_object = BertTaggerObject.objects.get(pk=self.test_imported_binary_gpu_tagger_id)
        # Wait until the task has finished
        while tagger_object.task.status != Task.STATUS_COMPLETED:
            print_output('test_apply_binary_bert_tagger_to_index: waiting for applying tagger task to finish, current status:', tagger_object.task.status)
            sleep(2)
            tagger_object.task.refresh_from_db()

        results = ElasticAggregator(indices=[self.test_index_copy]).get_fact_values_distribution(self.new_fact_name)
        print_output("test_apply_binary_bert_tagger_to_index:elastic aggregator results:", results)
        # Check if the expected number of facts were added to the index
        expected_number_of_facts = 29
        self.assertTrue(results[self.new_fact_value] == expected_number_of_facts)
        self.add_cleanup_files(self.test_imported_binary_gpu_tagger_id)

    def run_apply_multiclass_tagger_to_index(self):
        """Tests applying a multiclass BERT tagger to an index using the apply_to_index endpoint."""
        # Make sure the reindexer task has finished
        while self.reindexer_object.task.status != Task.STATUS_COMPLETED:
            print_output('test_apply_multiclass_bert_tagger_to_index: waiting for reindexer task to finish, current status:', self.reindexer_object.task.status)
            sleep(2)
            self.reindexer_object.task.refresh_from_db()

        url = f'{self.url}{self.test_imported_multiclass_gpu_tagger_id}/apply_to_index/'
        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_multiclass_fact_name,
            "new_fact_value": self.new_fact_value,
            "indices": [{"name": self.test_index_copy}],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url, payload, format='json')
        print_output('test_apply_multiclass_bert_tagger_to_index:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        tagger_object = BertTaggerObject.objects.get(pk=self.test_imported_multiclass_gpu_tagger_id)
        # Wait until the task has finished
        while tagger_object.task.status != Task.STATUS_COMPLETED:
            print_output('test_apply_multiclass_bert_tagger_to_index: waiting for applying tagger task to finish, current status:', tagger_object.task.status)
            sleep(2)
            tagger_object.task.refresh_from_db()

        results = ElasticAggregator(indices=[self.test_index_copy]).get_fact_values_distribution(self.new_multiclass_fact_name)
        print_output("test_apply_multiclass_bert_tagger_to_index:elastic aggregator results:", results)
        # Check if the expected facts and the expected number of them were added to the index
        expected_fact_value = "bar"
        expected_number_of_facts = 30
        self.assertTrue(expected_fact_value in results)
        self.assertTrue(results[expected_fact_value] == expected_number_of_facts)
        self.add_cleanup_files(self.test_imported_multiclass_gpu_tagger_id)

    def run_apply_tagger_to_index_invalid_input(self):
        """Tests applying a BERT tagger to an index with invalid input."""
        url = f'{self.url}{self.test_tagger_id}/apply_to_index/'
        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "fields": "invalid_field_format",
            "bulk_size": 100,
            "query": json.dumps(TEST_QUERY)
        }
        response = self.client.post(url, payload, format='json')
        print_output('test_invalid_apply_bert_tagger_to_index:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.add_cleanup_files(self.test_tagger_id)

    def run_bert_tag_and_feedback_and_retrain(self):
        """Tests the feedback extra action."""
        # Get basic information to check for previous tagger deletion after retraining has finished.
        tagger_orm: BertTaggerObject = BertTaggerObject.objects.get(pk=self.test_tagger_id)
        model_path = pathlib.Path(tagger_orm.model.path)
        print_output('run_bert_tag_and_feedback_and_retrain:assert that previous model exists', data=model_path.exists())
        self.assertTrue(model_path.exists())

        payload = {
            "text": "This is some test text for the Tagger Test",
            "feedback_enabled": True
        }
        tag_text_url = f'{self.url}{self.test_tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output('test_bert_tag_text_with_feedback:response.data', response.data)
        self.assertTrue('feedback' in response.data)

        # Generate feedback
        fb_id = response.data['feedback']['id']
        feedback_url = f'{self.url}{self.test_tagger_id}/feedback/'
        payload = {"feedback_id": fb_id, "correct_result": "FUBAR"}
        response = self.client.post(feedback_url, payload, format='json')
        print_output('test_bert_tag_text_feedback:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue('success' in response.data)
        # Sleep for a sec to allow Elastic to finish its business
        sleep(1)

        # List feedback
        feedback_list_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output('test_tag_text_list_feedback:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue(len(response.data) > 0)

        # Retrain the model
        url = f'{self.url}{self.test_tagger_id}/retrain_tagger/'
        response = self.client.post(url)
        print_output('test_bert_tagger_feedback:retrain', response.data)

        # Test tagging again with the retrained model
        payload = {"text": "This is some test text for the Tagger Test"}
        tag_text_url = f'{self.url}{self.test_tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output('test_bert_tagger_feedback_retrained_tag_doc:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('result' in response.data)
        self.assertTrue('probability' in response.data)

        # Ensure that the previous tagger model is deleted properly.
        print_output('test_model_retrain:assert that previous model doesnt exist', data=model_path.exists())
        self.assertFalse(model_path.exists())
        # Ensure that the freshly created model wasn't deleted.
        tagger_orm.refresh_from_db()
        self.assertNotEqual(tagger_orm.model.path, str(model_path))

        # Delete feedback
        feedback_delete_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.delete(feedback_delete_url)
        print_output('test_bert_tagger_tag_doc_delete_feedback:response.data', response.data)
        # Sleep for a sec to allow Elastic to finish its business
        sleep(1)

        # List feedback again to make sure it's empty
        feedback_list_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output('test_bert_tagger_tag_doc_list_feedback_after_delete:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(len(response.data) == 0)

        # Remove the created feedback index
        feedback_index_url = f'{self.project_url}/feedback/'
        response = self.client.delete(feedback_index_url)
        print_output('test_bert_tagger_delete_feedback_index:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('success' in response.data)
        self.add_cleanup_files(self.test_tagger_id)

    def run_bert_tag_text_persistent(self):
        """Tests tag prediction for texts using persistent models."""
        payload = {"text": "mine kukele, loll", "persistent": True}
        # First try
        start_1 = time_ns()
        response = self.client.post(f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        end_1 = time_ns() - start_1
        # Give time for persistent taggers to load
        sleep(1)
        # Second try
        start_2 = time_ns()
        response = self.client.post(f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        end_2 = time_ns() - start_2
        # Test if the second attempt is faster
        print_output('test_bert_tagger_persistent speed:', (end_1, end_2))
        assert end_2 < end_1

    def run_test_that_user_cant_delete_pretrained_model(self):
        self.client.login(username='******', password='******')
        url = reverse("v2:bert_tagger-delete-pretrained-model", kwargs={"project_pk": self.project.pk})
        resp = self.client.post(url, data={"model_name": "EMBEDDIA/finest-bert"})
        print_output("run_test_that_user_cant_delete_pretrained_model:response.data", data=resp.data)
        self.assertTrue(resp.status_code == status.HTTP_401_UNAUTHORIZED or resp.status_code == status.HTTP_403_FORBIDDEN)

    def run_test_that_admin_users_can_delete_pretrained_model(self):
        self.client.login(username="******", password='******')
        url = reverse("v2:bert_tagger-delete-pretrained-model", kwargs={"project_pk": self.project.pk})
        model_name = "prajjwal1/bert-tiny"
        file_name = BertTagger.normalize_name(model_name)
        model_path = pathlib.Path(BERT_PRETRAINED_MODEL_DIRECTORY) / file_name
        self.assertTrue(model_path.exists() is True)
        resp = self.client.post(url, data={"model_name": model_name})
        print_output("run_test_that_admin_users_can_delete_pretrained_model:response.data", data=resp.data)
        self.assertTrue(resp.status_code == status.HTTP_200_OK)
        self.assertTrue(model_path.exists() is False)
class ReindexerViewTests(APITransactionTestCase):
    def setUp(self):
        """ The user needs to be an admin because of the changed indices permissions. """
        self.test_index_name = reindex_test_dataset()
        self.default_password = '******'
        self.default_username = '******'
        self.user = create_test_user(self.default_username, '*****@*****.**', self.default_password)
        # Create an admin to test indices removal from the project
        self.admin = create_test_user(name='admin', password='******')
        self.admin.is_superuser = True
        self.admin.save()
        self.project = project_creation("ReindexerTestProject", self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.ec = ElasticCore()
        self.client.login(username=self.default_username, password=self.default_password)
        self.new_index_name = f"{TEST_FIELD}_2"

    def tearDown(self) -> None:
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])

    def test_run(self):
        existing_new_index_payload = {
            "description": "TestWrongField",
            "indices": [self.test_index_name],
            "new_index": REINDEXER_TEST_INDEX,  # index created for test purposes
        }
        wrong_fields_payload = {
            "description": "TestWrongField",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "fields": ['12345'],
        }
        wrong_indices_payload = {
            "description": "TestWrongIndex",
            "indices": ["Wrong_Index"],
            "new_index": TEST_INDEX_REINDEX,
        }
        pick_fields_payload = {
            "description": "TestManyReindexerFields",
            # this has a problem with possible name duplicates
            "fields": [TEST_FIELD, 'comment_content_clean.text', 'texta_facts'],
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
        }
        # duplicate name problem?
        # if you want to actually test it, add an index to indices and project indices
        join_indices_fields_payload = {
            "description": "TestReindexerJoinFields",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
        }
        test_query_payload = {
            "description": "TestQueryFiltering",
            "scroll_size": 100,
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "query": json.dumps(TEST_QUERY)
        }
        random_docs_payload = {
            "description": "TestReindexerRandomFields",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "random_size": 500,
        }
        update_field_type_payload = {
            "description": "TestReindexerUpdateFieldType",
            "fields": [],
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "field_type": [
                {"path": "comment_subject", "field_type": "long", "new_path_name": "CHANGED_NAME"},
                {"path": "comment_content_lemmas", "field_type": "fact", "new_path_name": "CHANGED_TOO"},
                {"path": "comment_content_clean.stats.text_length", "field_type": "boolean", "new_path_name": "CHANGED_AS_WELL"},
            ],
        }

        for REINDEXER_VALIDATION_TEST_INDEX in (
                REINDEXER_VALIDATION_TEST_INDEX_1,
                REINDEXER_VALIDATION_TEST_INDEX_2,
                REINDEXER_VALIDATION_TEST_INDEX_3,
                REINDEXER_VALIDATION_TEST_INDEX_4,
                REINDEXER_VALIDATION_TEST_INDEX_5,
                REINDEXER_VALIDATION_TEST_INDEX_6):
            new_index_validation_payload = {
                "description": "TestNewIndexValidation",
                "indices": [self.test_index_name],
                "new_index": REINDEXER_VALIDATION_TEST_INDEX
            }
            url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
            self.check_new_index_validation(url, new_index_validation_payload)

        for payload in (
                existing_new_index_payload,
                wrong_indices_payload,
                wrong_fields_payload,
                pick_fields_payload,
                join_indices_fields_payload,
                test_query_payload,
                random_docs_payload,
                update_field_type_payload,
        ):
            url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
            self.run_create_reindexer_task_signal(self.project, url, payload)

    def run_create_reindexer_task_signal(self, project, url, payload, overwrite=False):
        """ Tests the endpoint for a new Reindexer task and whether a new Task gets created
        via the signal; also checks that new_index was removed. """
        try:
            self.ec.delete_index(TEST_INDEX_REINDEX)
        except Exception:
            print(f'{TEST_INDEX_REINDEX} was not deleted')
        response = self.client.post(url, payload, format='json')
        print_output('run_create_reindexer_task_signal:response.data', response.data)
        self.check_update_forbidden(url, payload)
        self.is_new_index_created_if_yes_remove(response, payload, project)
        self.is_reindexed_index_added_to_project_if_yes_remove(response, payload['new_index'], project)
        assert TEST_INDEX_REINDEX not in ElasticCore().get_indices()

    def check_new_index_validation(self, url, new_index_validation_payload):
        response = self.client.post(url, new_index_validation_payload, format='json')
        print_output('new_index_validation:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.data["detail"].code, "invalid_index_name")

    def is_new_index_created_if_yes_remove(self, response, payload, project):
        """ Check if new_index gets created and re-indexed to completion,
        then remove the test new_index. """
        if project.get_indices() is None or response.exception:
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        else:
            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
            created_reindexer = Reindexer.objects.get(id=response.data['id'])
            print_output("Re-index task status: ", created_reindexer.task.status)
            self.assertEqual(created_reindexer.task.status, Task.STATUS_COMPLETED)
            # self.check_positive_doc_count()
            new_index = response.data['new_index']
            delete_response = self.ec.delete_index(new_index)
            print_output("Reindexer Test index remove status", delete_response)

    def is_reindexed_index_added_to_project_if_yes_remove(self, response, new_index, project):
        # The project resource user is not supposed to have the indices-remove permission, so use admin.
        self.client.login(username='******', password='******')
        url = f'{TEST_VERSION_PREFIX}/projects/{project.id}/'
        check = self.client.get(url, format='json')
        if response.status_code == 201:
            assert new_index in [index["name"] for index in check.data['indices']]
            print_output('Re-indexed index added to project', check.data)
            index_pk = Index.objects.get(name=new_index).pk
            remove_index_url = reverse(f"{VERSION_NAMESPACE}:project-remove-indices", kwargs={"pk": self.project.pk})
            remove_response = self.client.post(remove_index_url, {"indices": [index_pk]}, format='json')
            print_output("Re-indexed index removed from project", remove_response.status_code)
            self.delete_reindexing_task(project, response)
        if response.status_code == 400:
            print_output('Re-indexed index not added to project', check.data)
            check = self.client.get(url, format='json')
            assert new_index not in [index["name"] for index in check.data['indices']]
        # Log in with the project user again
        self.client.login(username=self.default_username, password=self.default_password)

    def validate_fields(self, project, payload):
        project_fields = self.ec.get_fields(project.get_indices())
        project_field_paths = [field["path"] for field in project_fields]
        for field in payload['fields']:
            if field not in project_field_paths:
                return False
        return True

    def validate_indices(self, project, payload):
        for index in payload['indices']:
            if index not in project.get_indices():
                return False
        return True

    def check_positive_doc_count(self):
        # Current reindexing tests require an approx. 2-second delay
        sleep(5)
        count_new_documents = ElasticSearcher(indices=TEST_INDEX_REINDEX).count()
        print_output("Bulk add doc count", count_new_documents)
        assert count_new_documents > 0

    def check_update_forbidden(self, url, payload):
        put_response = self.client.put(url, payload, format='json')
        patch_response = self.client.patch(url, payload, format='json')
        print_output("put_response.data", put_response.data)
        print_output("patch_response.data", patch_response.data)
        self.assertEqual(put_response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
        self.assertEqual(patch_response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    def delete_reindexing_task(self, project, response):
        """ Tests deleting a reindex task. """
        task_url = response.data['url']
        get_response = self.client.get(task_url)
        self.assertEqual(get_response.status_code, status.HTTP_200_OK)
        delete_response = self.client.delete(task_url, format='json')
        self.assertEqual(delete_response.status_code, status.HTTP_204_NO_CONTENT)
        get_response = self.client.get(task_url)
        self.assertEqual(get_response.status_code, status.HTTP_404_NOT_FOUND)

    def test_that_changing_field_names_works(self):
        payload = {
            "description": "RenameFieldName",
            "new_index": self.new_index_name,
            "fields": [TEST_FIELD],
            "field_type": [{"path": TEST_FIELD, "new_path_name": TEST_FIELD_RENAMED, "field_type": "text"}],
            "indices": [self.test_index_name],
            "add_facts_mapping": True
        }
        # Reindex the test index into a new one.
        url = reverse("v2:reindexer-list", kwargs={"project_pk": self.project.pk})
        reindex_response = self.client.post(url, data=payload, format='json')
        print_output('test_that_changing_field_names_works:response.data', reindex_response.data)

        # Check that the fields have been changed.
        es = ElasticSearcher(indices=[self.new_index_name])
        for document in es:
            self.assertTrue(TEST_FIELD not in document)
            self.assertTrue(TEST_FIELD_RENAMED in document)

        # Manual clean-up.
        es.core.delete_index(self.new_index_name)
class IndexSplitterViewTests(APITransactionTestCase):
    def setUp(self):
        """ The user needs to be an admin because of the changed indices permissions. """
        self.test_index_name = reindex_test_dataset()
        self.default_password = '******'
        self.default_username = '******'
        self.user = create_test_user(self.default_username, '*****@*****.**', self.default_password)
        self.admin = create_test_user(name='admin', password='******')
        self.admin.is_superuser = True
        self.admin.save()
        self.project = project_creation("IndexSplittingTestProject", self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/index_splitter/'
        self.client.login(username=self.default_username, password=self.default_password)
        self.ec = ElasticCore()
        self.FACT = "TEEMA"

    def test_create_splitter_object_and_task_signal(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }
        response = self.client.post(self.url, json.dumps(payload), content_type='application/json')
        print_output('test_create_splitter_object_and_task_signal:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        print_output("indices:", splitter_obj.get_indices())
        # Check if the IndexSplitter object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Check if a Task gets created
        self.assertTrue(splitter_obj.task is not None)
        print_output("status of IndexSplitter's Task object", splitter_obj.task.status)
        # Check if the Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)

        sleep(5)

        original_count = ElasticSearcher(indices=self.test_index_name).count()
        test_count = ElasticSearcher(indices=INDEX_SPLITTING_TEST_INDEX).count()
        train_count = ElasticSearcher(indices=INDEX_SPLITTING_TRAIN_INDEX).count()
        print_output('original_count, test_count, train_count', [original_count, test_count, train_count])

    def test_create_random_split(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }
        response = self.client.post(self.url, data=payload)
        print_output('test_create_random_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        # Assert the Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_count = ElasticSearcher(indices=self.test_index_name).count()
        test_count = ElasticSearcher(indices=INDEX_SPLITTING_TEST_INDEX).count()
        train_count = ElasticSearcher(indices=INDEX_SPLITTING_TRAIN_INDEX).count()
        print_output('original_count, test_count, train_count', [original_count, test_count, train_count])

        # To avoid any inconsistencies caused by rounding, assume the sizes are between small limits
        self.assertTrue(self.is_between_limits(test_count, original_count, 0.2))
        self.assertTrue(self.is_between_limits(train_count, original_count, 0.8))

    def test_create_original_split(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT
        }
        response = self.client.post(self.url, data=payload)
        print_output('test_create_original_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        # Assert the Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(indices=self.test_index_name).get_fact_values_distribution(self.FACT)
        test_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(self.FACT)
        train_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(self.FACT)
        print_output('original_dist, test_dist, train_dist', [original_distribution, test_distribution, train_distribution])

        for label, quant in original_distribution.items():
            self.assertTrue(self.is_between_limits(test_distribution[label], quant, 0.2))
            self.assertTrue(self.is_between_limits(train_distribution[label], quant, 0.8))

    def test_create_equal_split(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "equal",
            "test_size": 20,
            "fact": self.FACT
        }
        response = self.client.post(self.url, data=payload)
        print_output('test_create_equal_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        # Assert the Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(indices=self.test_index_name).get_fact_values_distribution(self.FACT)
        test_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(self.FACT)
        train_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(self.FACT)
        print_output('original_dist, test_dist, train_dist', [original_distribution, test_distribution, train_distribution])

        for label, quant in original_distribution.items():
            if quant > 20:
                self.assertEqual(test_distribution[label], 20)
                self.assertEqual(train_distribution[label], quant - 20)
            else:
                self.assertEqual(test_distribution[label], quant)
                self.assertTrue(label not in train_distribution)

    def test_create_custom_split(self):
        custom_distribution = {"FUBAR": 10, "bar": 15}
        payload = {
            "description": "Original index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "custom",
            "fact": self.FACT,
            "custom_distribution": json.dumps(custom_distribution)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_create_custom_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        # Assert the Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(indices=self.test_index_name).get_fact_values_distribution(self.FACT)
        test_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(self.FACT)
        train_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(self.FACT)
        print_output('original_dist, test_dist, train_dist', [original_distribution, test_distribution, train_distribution])

        for label, quant in custom_distribution.items():
            self.assertEqual(test_distribution[label], min(quant, original_distribution[label]))

        for label in original_distribution.keys():
            if label not in custom_distribution:
                self.assertTrue(label not in test_distribution)
                self.assertEqual(original_distribution[label], train_distribution[label])

    def test_create_original_split_fact_value_given(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT,
            "str_val": "FUBAR"
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_create_original_split_fact_value_given:response.data', response.data)
        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        sleep(5)

        original_distribution = ElasticAggregator(indices=self.test_index_name).get_fact_values_distribution(self.FACT)
        test_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(self.FACT)
        train_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(self.FACT)
        print_output('original_dist, test_dist, train_dist', [original_distribution, test_distribution, train_distribution])

        for label, quant in original_distribution.items():
            if label == "FUBAR":
                self.assertTrue(self.is_between_limits(test_distribution[label], quant, 0.2))
                self.assertTrue(self.is_between_limits(train_distribution[label], quant, 0.8))

    def test_query_given(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT,
            "str_val": "bar",
            "query": json.dumps(TEST_QUERY)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_query_given:response.data', response.data)

        original_distribution = ElasticAggregator(indices=self.test_index_name).get_fact_values_distribution(self.FACT)
        test_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(self.FACT)
        train_distribution = ElasticAggregator(indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(self.FACT)
        print_output('original_dist, test_dist, train_dist', [original_distribution, test_distribution, train_distribution])

        self.assertTrue("bar" in test_distribution)
        self.assertTrue("bar" in train_distribution)
        self.assertTrue("foo" not in train_distribution and "foo" not in test_distribution)
        self.assertTrue("FUBAR" not in train_distribution and "FUBAR" not in test_distribution)

    def tearDown(self):
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
        res = self.ec.delete_index(INDEX_SPLITTING_TEST_INDEX)
        print_output('attempt to delete test index:', res)
        res = self.ec.delete_index(INDEX_SPLITTING_TRAIN_INDEX)
        print_output('attempt to delete train index:', res)

    def is_between_limits(self, value, base, ratio):
        return base * ratio - 1 <= value <= base * ratio + 1

    # There used to be a bug in which objects were flattened in the split index unintentionally.
    def test_that_split_index_with_nested_field_still_has_nested_field(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{"name": self.test_index_name}],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_that_split_index_with_nested_field_still_has_nested_field:response.data', response.data)

        at_least_once = False
        es = ElasticSearcher(
            indices=[INDEX_SPLITTING_TEST_INDEX, INDEX_SPLITTING_TRAIN_INDEX],
            field_data=[TEST_INDEX_OBJECT_FIELD],
            flatten=False
        )
        for item in es:
            data = item.get(TEST_INDEX_OBJECT_FIELD, None)
            if data:
                self.assertTrue(isinstance(data, dict))
                at_least_once = True
        self.assertTrue(at_least_once)
def tearDown(self) -> None:
    ec = ElasticCore()
    ec.delete_index(self.test_index_name)
class MLPIndexProcessing(APITransactionTestCase):

    def setUp(self):
        self.test_index_name = reindex_test_dataset()
        self.ec = ElasticCore()
        self.user = create_test_user('mlpUser', '*****@*****.**', 'pw')
        self.project = project_creation("mlpTestProject", self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.client.login(username='******', password='******')
        self.url = reverse(f"{VERSION_NAMESPACE}:mlp_index-list", kwargs={"project_pk": self.project.pk})

    def tearDown(self) -> None:
        self.ec.delete_index(self.test_index_name, ignore=[400, 404])

    def _assert_mlp_contents(self, hit: dict, test_field: str):
        self.assertTrue(f"{test_field}_mlp.lemmas" in hit)
        self.assertTrue(f"{test_field}_mlp.pos_tags" in hit)
        self.assertTrue(f"{test_field}_mlp.text" in hit)
        self.assertTrue(f"{test_field}_mlp.language.analysis" in hit)
        self.assertTrue(f"{test_field}_mlp.language.detected" in hit)

    def test_index_processing(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "fields": [TEST_FIELD],
            "query": json.dumps({'query': {'match': {'comment_content_lemmas': query_string}}}, ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_index_processing:response.data", response.data)
        # Check if MLP was applied to the documents properly.
        s = ElasticSearcher(indices=[self.test_index_name], output=ElasticSearcher.OUT_DOC, query=payload["query"])
        for hit in s:
            self._assert_mlp_contents(hit, TEST_FIELD)

    def _check_for_if_query_correct(self, hit: dict, field_name: str, query_string: str):
        text = hit[field_name]
        self.assertTrue(query_string in text)

    def test_payload_without_fields_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "query": json.dumps({'query': {'match': {'comment_content_lemmas': query_string}}}, ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_without_fields_value:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)
        self.assertTrue(response.data["fields"][0] == "This field is required.")

    def test_payload_with_empty_fields_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "query": json.dumps({'query': {'match': {'comment_content_lemmas': query_string}}}, ensure_ascii=False),
            "fields": []
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_with_empty_fields_value:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)
        self.assertTrue(response.data["fields"][0] == "This list may not be empty.")

    def test_payload_with_invalid_field_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "fields": ["this_field_does_not_exist"],
            "query": json.dumps({'query': {'match': {'comment_content_lemmas': query_string}}}, ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_with_invalid_field_value:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)

    def test_applying_mlp_on_two_indices(self):
        query_string = "inimene"
        indices = [f"texta_test_{uuid.uuid1()}", f"texta_test_{uuid.uuid1()}"]
        for index in indices:
            self.ec.es.indices.create(index=index, ignore=[400, 404])
            self.ec.es.index(index=index, body={"text": "obscure content to parse!"})
            # Avoid shadowing the loop variable with the Index model instance.
            index_obj, is_created = Index.objects.get_or_create(name=index)
            self.project.indices.add(index_obj)
        payload = {
            "description": "TestingIndexProcessing",
            "fields": ["text"],
            "indices": [{"name": index} for index in indices],
            "query": json.dumps({'query': {'match': {'comment_content_lemmas': query_string}}}, ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_applying_mlp_on_two_indices:response.data", response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        for index in indices:
            self.ec.es.indices.delete(index=index, ignore=[400, 404])
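test_applying_mlp_on_two_indices only removes its throwaway indices at the very end of the test body, so a failed assertion can leave them behind. A small sketch of the same setup using addCleanup; the helper name _create_temp_index is hypothetical and simply reuses the ElasticCore and Index calls from the test above:

# Sketch only: register cleanup right after creating each throwaway index,
# so it is removed even when a later assertion fails.
def _create_temp_index(self, name: str):
    self.ec.es.indices.create(index=name, ignore=[400, 404])
    self.addCleanup(self.ec.es.indices.delete, index=name, ignore=[400, 404])
    index_obj, _ = Index.objects.get_or_create(name=name)
    self.project.indices.add(index_obj)
    return index_obj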
def tearDown(self):
    # Delete the indices created during the test.
    ec = ElasticCore()
    for index in self.created_indices:
        delete_response = ec.delete_index(index)
        print_output("Remove index:response", delete_response)
def __remove_reindexed_test_index(self):
    ec = ElasticCore()
    result = ec.delete_index(index=self.test_index_name, ignore=[400, 404])
    print_output(f"Deleting ProjectViewTests test index {self.test_index_name}", result)
class TestDocparserAPIView(APITestCase): def setUp(self) -> None: self.test_index_name = reindex_test_dataset() self.user = create_test_user('Owner', '*****@*****.**', 'pw') self.unauthorized_user = create_test_user('unauthorized', '*****@*****.**', 'pw') self.file_name = "d41d8cd98f00b204e9800998ecf8427e.txt" self.project = project_creation("test_doc_parser", index_title=None, author=self.user) self.project.users.add(self.user) self.unauth_project = project_creation("unauth_project", index_title=None, author=self.user) self.file = SimpleUploadedFile("text.txt", b"file_content", content_type="text/html") self.client.login(username='******', password='******') self._basic_pipeline_functionality() self.file_path = self._get_file_path() self.ec = ElasticCore() def tearDown(self) -> None: self.ec.delete_index(index=self.test_index_name, ignore=[400, 404]) def _get_file_path(self): path = pathlib.Path(settings.RELATIVE_PROJECT_DATA_PATH) / str(self.project.pk) / "docparser" / self.file_name return path def _basic_pipeline_functionality(self): url = reverse(f"{VERSION_NAMESPACE}:docparser") payload = { "file": self.file, "project_id": self.project.pk, "indices": [self.test_index_name], "file_name": self.file_name } response = self.client.post(url, data=payload) print_output("_basic_pipeline_functionality:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK) def test_file_appearing_in_proper_structure(self): print_output("test_file_appearing_in_proper_structure", self.file_path.exists()) self.assertTrue(self.file_path.exists()) def test_being_rejected_without_login(self): url = reverse(f"{VERSION_NAMESPACE}:docparser") self.client.logout() payload = { "file": self.file, "project_id": self.project.pk, "indices": [self.test_index_name] } response = self.client.post(url, data=payload) print_output("test_being_rejected_without_login:response.data", response.data) response_code = response.status_code print_output("test_being_rejected_without_login:response.status_code", response_code) self.assertTrue((response_code == status.HTTP_403_FORBIDDEN) or (response_code == status.HTTP_401_UNAUTHORIZED)) def test_being_rejected_with_wrong_project_id(self): url = reverse(f"{VERSION_NAMESPACE}:docparser") payload = { "file": self.file, "project_id": self.unauth_project.pk, "indices": [self.test_index_name] } self.unauth_project.users.remove(self.user) response = self.client.post(url, data=payload) print_output("test_being_rejected_with_wrong_project_id:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN) def test_indices_being_added_into_the_project(self): project = Project.objects.get(pk=self.project.pk) indices = project.indices.all() added_index = indices.filter(name=self.test_index_name) self.assertTrue(added_index.count() == 1) print_output("test_indices_being_added_into_the_project", True) def test_that_serving_media_works_for_authenticated_users(self): file_name = self.file_path.name url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name}) response = self.client.get(url) print_output("test_that_serving_media_works_for_authenticated_users", True) self.assertTrue(response.status_code == status.HTTP_200_OK) def test_that_serving_media_doesnt_work_for_unauthenticated_users(self): self.client.logout() file_name = self.file_path.name url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name}) 
response = self.client.get(url) print_output("test_that_serving_media_doesnt_work_for_unauthenticated_users", True) self.assertTrue(response.status_code == status.HTTP_302_FOUND) def test_media_access_for_unauthorized_projects(self): self.client.login(username="******", password="******") file_name = self.file_path.name url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name}) response = self.client.get(url) print_output("test_media_access_for_unauthorized_projects", True) self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN) def test_that_saved_file_size_isnt_zero(self): """ Necessary because of a prior bug where the wrapper would save a file with the right name but not its contents. """ import time time.sleep(10) file_size = os.path.getsize(self.file_path) self.assertTrue(file_size > 1) print_output("test_that_saved_file_size_isnt_zero::file_size:int", file_size) def test_payload_with_empty_indices(self): url = reverse(f"{VERSION_NAMESPACE}:docparser") payload = { "file": SimpleUploadedFile("text.txt", b"file_content", content_type="text/html"), "project_id": self.project.pk, "file_name": self.file_name } response = self.client.post(url, data=payload) print_output("test_payload_with_empty_indices:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK)
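The docstring of test_that_saved_file_size_isnt_zero explains why the check exists, but the fixed time.sleep(10) makes the test slow and still racy. A sketch of a polling alternative; the helper name, timeout and interval are hypothetical and not part of the suite:

import os
import time


# Sketch only: wait for the parsed file to become non-empty instead of
# sleeping for a fixed ten seconds.
def wait_for_nonempty_file(path, timeout: float = 10.0, interval: float = 0.5) -> int:
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(path) and os.path.getsize(path) > 0:
            return os.path.getsize(path)
        time.sleep(interval)
    raise TimeoutError(f"{path} stayed empty for {timeout} seconds")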
class DocumentImporterAPITestCase(APITestCase): def setUp(self): self.test_index_name = reindex_test_dataset() self.user = create_test_user('first_user', '*****@*****.**', 'pw') self.project = project_creation("DocumentImporterAPI", self.test_index_name, self.user) self.validation_project = project_creation("validation_project", "random_index_name", self.user) self.document_id = random.randint(10000000, 90000000) self.uuid = uuid.uuid1() self.source = {"hello": "world", "uuid": self.uuid} self.document = {"_index": self.test_index_name, "_id": self.document_id, "_source": self.source} self.target_field_random_key = uuid.uuid1() self.target_field = f"{self.target_field_random_key}_court_case" self.ec = ElasticCore() self.client.login(username='******', password='******') self._check_inserting_documents() def _check_inserting_documents(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) response = self.client.post(url, data={"documents": [self.document], "split_text_in_fields": []}, format="json") self.assertTrue(response.status_code == status.HTTP_200_OK) document = self.ec.es.get(id=self.document_id, index=self.test_index_name) print_output("_check_inserting_documents:response.data", response.data) self.assertTrue(document["_source"]) def tearDown(self) -> None: self.ec.delete_index(index=self.test_index_name, ignore=[400, 404]) query = Search().query(Q("exists", field=self.target_field)).to_dict() self.ec.es.delete_by_query(index="*", body=query, wait_for_completion=True) def test_adding_documents_to_false_project(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.validation_project.pk}) self.validation_project.users.remove(self.user) response = self.client.post(url, data={"documents": [self.document]}, format="json") self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN) print_output("test_adding_documents_to_false_project:response.data", response.data) def test_adding_documents_to_false_index(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) index_name = "wrong_index" response = self.client.post(url, data={"documents": [{"_index": index_name, "_source": self.document}]}, format="json") try: self.ec.es.get(id=self.document_id, index=index_name) except NotFoundError: print_output("test_adding_documents_to_false_index:response.data", response.data) else: raise Exception("Elasticsearch indexed a document it shouldn't have!") def test_updating_document(self): url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": self.document_id}) response = self.client.patch(url, data={"hello": "night", "goodbye": "world"}) self.assertTrue(response.status_code == status.HTTP_200_OK) document = self.ec.es.get(index=self.test_index_name, id=self.document_id)["_source"] self.assertTrue(document["hello"] == "night" and document["goodbye"] == "world") print_output("test_updating_document:response.data", response.data) def test_deleting_document(self): url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": self.document_id}) response = self.client.delete(url) self.assertTrue(response.status_code == status.HTTP_200_OK) try: self.ec.es.get(id=self.document_id, index=self.test_index_name) except NotFoundError: print_output("test_deleting_document:response.data", response.data) else: raise Exception("Elasticsearch didnt delete a document it 
should have!") def test_unauthenticated_access(self): self.client.logout() url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) response = self.client.post(url, data={"documents": [self.document]}, format="json") print_output("test_unauthenticated_access:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN or response.status_code == status.HTTP_401_UNAUTHORIZED) def test_adding_document_without_specified_index_and_that_index_is_added_into_project(self): from toolkit.elastic.document_importer.views import DocumentImportView sample_id = 65959645 url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) response = self.client.post(url, data={"documents": [{"_source": self.source, "_id": sample_id}]}, format="json") self.assertTrue(response.status_code == status.HTTP_200_OK) normalized_project_title = DocumentImportView.get_new_index_name(self.project.pk) document = self.ec.es.get(id=sample_id, index=normalized_project_title) self.ec.delete_index(normalized_project_title) # Cleanup self.assertTrue(document["_source"]) self.assertTrue(self.project.indices.filter(name=normalized_project_title).exists()) print_output("test_adding_document_without_specified_index_and_that_index_is_added_into_project:response.data", response.data) def test_updating_non_existing_document(self): sample_id = "random_id" url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": sample_id}) response = self.client.patch(url, data={"hello": "world"}, format="json") print_output("test_updating_non_existing_document:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_404_NOT_FOUND) def test_deleting_non_existing_document(self): sample_id = "random_id" url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": sample_id}) response = self.client.delete(url, data={"hello": "world"}, format="json") print_output("test_deleting_non_existing_document:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_404_NOT_FOUND) def test_that_specified_field_is_being_split(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) uuid = "456694-asdasdad4-54646ad-asd4a5d" response = self.client.post( url, format="json", data={ "split_text_in_fields": [self.target_field], "documents": [{ "_index": self.test_index_name, "_source": { self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...", "uuid": uuid } }]}, ) print_output("test_that_specified_field_is_being_split:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK) documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}}) document = documents["hits"]["hits"][0] self.assertTrue(document) self.assertTrue(document["_source"]) self.assertTrue("page" in document["_source"]) def test_that_wrong_field_value_will_skip_splitting(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) uuid = "Adios" response = self.client.post( url, format="json", data={ "documents": [{ "_index": self.test_index_name, "_source": { self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...", "uuid": uuid } }]}, ) 
print_output("test_that_empty_field_value_will_skip_splitting:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK) documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}}) document = documents["hits"]["hits"][0] self.assertTrue(document) self.assertTrue(document["_source"]) self.assertTrue("page" not in document["_source"]) def test_splitting_behaviour_with_empty_list_as_input(self): url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) uuid = "adasdasd-5g465s-fa4s69f4a8s97-a4das9f4" response = self.client.post( url, format="json", data={ "split_text_in_fields": [], "documents": [{ "_index": self.test_index_name, "_source": { self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...", "uuid": uuid } }]}, ) print_output("test_splitting_behaviour_with_empty_list_as_input:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK) documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}}) document = documents["hits"]["hits"][0] self.assertTrue(document) self.assertTrue(document["_source"]) self.assertTrue("page" not in document["_source"]) def test_that_indexes_do_rollover_after_a_certain_count(self): set_core_setting("TEXTA_ES_MAX_DOCS_PER_INDEX", "50") url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk}) response = self.client.post(url, data={"documents": [{"_source": {"value": i}} for i in range(51)], "split_text_in_fields": []}, format="json") print_output("test_that_indexes_do_rollover_after_a_certain_count:response.data", response.data) self.assertTrue(response.status_code == status.HTTP_200_OK) # Send a batch for the second time to see if the documents rolled over properly. clean_up_container = ["texta-1-import-project-1"] for i in range(1, 12): response = self.client.post(url, data={"documents": [{"_source": {"value": i}} for i in range(51)], "split_text_in_fields": []}, format="json") self.assertTrue(response.status_code == status.HTTP_200_OK) if i == 1: self.assertTrue(self.project.indices.filter(name="texta-1-import-project-1").exists()) else: # i-1 because the first request doesn't add the rollover index counter. rollover_index_name = f"texta-1-import-project-1-{i - 1}" self.assertTrue(self.project.indices.filter(name=rollover_index_name).exists()) clean_up_container.append(rollover_index_name) # Cleanup self.ec.delete_index(index=",".join(clean_up_container))
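test_that_indexes_do_rollover_after_a_certain_count asserts a naming scheme in which the first import batch writes to the base index and every later batch rolls over to a numbered suffix. A sketch of that expectation as a standalone helper; the function is hypothetical and only mirrors the assertions above, not the toolkit's actual rollover code:

# Sketch only: the index name the test above expects for a given import batch.
# Batch 1 uses the base index; batch i (i >= 2) is expected to use "<base>-<i - 1>".
def expected_rollover_index(base: str, batch_number: int) -> str:
    return base if batch_number == 1 else f"{base}-{batch_number - 1}"


# For example: expected_rollover_index("texta-1-import-project-1", 3)
# returns "texta-1-import-project-1-2".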
class TorchTaggerViewTests(APITransactionTestCase): def setUp(self): # Owner of the project self.test_index_name = reindex_test_dataset() self.user = create_test_user('torchTaggerOwner', '*****@*****.**', 'pw') self.project = project_creation('torchTaggerTestProject', self.test_index_name, self.user) self.project.users.add(self.user) self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/torchtaggers/' self.project_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}' self.test_embedding_id = None self.torch_models = list(TORCH_MODELS.keys()) self.test_tagger_id = None self.test_multiclass_tagger_id = None self.client.login(username='******', password='******') # new fact name and value used when applying tagger to index self.new_fact_name = "TEST_TORCH_TAGGER_NAME" self.new_multiclass_fact_name = "TEST_TORCH_TAGGER_NAME_MC" self.new_fact_value = "TEST_TORCH_TAGGER_VALUE" # Create copy of test index self.reindex_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/' # Generate name for new index containing random id to make sure it doesn't already exist self.test_index_copy = f"test_apply_torch_tagger_{uuid.uuid4().hex}" self.reindex_payload = { "description": "test index for applying taggers", "indices": [self.test_index_name], "query": json.dumps(TEST_QUERY), "new_index": self.test_index_copy, "fields": TEST_FIELD_CHOICE } resp = self.client.post(self.reindex_url, self.reindex_payload, format='json') print_output( "reindex test index for applying torch tagger:response.data:", resp.json()) self.reindexer_object = Reindexer.objects.get(pk=resp.json()["id"]) self.ec = ElasticCore() def import_test_model(self, file_path: str): """Import models for testing.""" print_output("Importing model from file:", file_path) files = {"file": open(file_path, "rb")} import_url = f'{self.url}import_model/' resp = self.client.post(import_url, data={ 'file': open(file_path, "rb") }).json() print_output("Importing test model:", resp) return resp["id"] def tearDown(self) -> None: res = self.ec.delete_index(self.test_index_copy) self.ec.delete_index(index=self.test_index_name, ignore=[400, 404]) print_output( f"Delete apply_torch_taggers test index {self.test_index_copy}", res) def test(self): pass # self.run_train_embedding() # self.run_train_tagger_using_query() # self.run_train_torchtagger_without_embedding() # self.run_train_multiclass_tagger_using_fact_name() # self.run_train_balanced_multiclass_tagger_using_fact_name() # self.run_train_binary_multiclass_tagger_using_fact_name() # self.run_train_binary_multiclass_tagger_using_fact_name_invalid_payload() # self.run_tag_text() # self.run_model_export_import() # self.run_tag_with_imported_gpu_model() # were already commented out # self.run_tag_with_imported_cpu_model() # were already commented out # self.run_tag_random_doc() # self.run_epoch_reports_get() # self.run_epoch_reports_post() # self.run_tag_and_feedback_and_retrain() # self.run_apply_binary_tagger_to_index() # self.run_apply_tagger_to_index_invalid_input() def add_cleanup_files(self, tagger_id): tagger_object = TorchTagger.objects.get(pk=tagger_id) self.addCleanup(remove_file, tagger_object.model.path) if not TEST_KEEP_PLOT_FILES: self.addCleanup(remove_file, tagger_object.plot.path) self.addCleanup(remove_file, tagger_object.embedding.embedding_model.path) def run_train_embedding(self): # payload for training embedding payload = { "description": "TestEmbedding", "fields": TEST_FIELD_CHOICE, "max_vocab": 10000, "min_freq": 5, "num_dimensions": 300 } # post 
embeddings_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/embeddings/' response = self.client.post(embeddings_url, payload, format='json') self.test_embedding_id = response.data["id"] print_output("run_train_embedding", 201) def run_train_torchtagger_without_embedding(self): payload = { "description": "TestTorchTaggerTraining", "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 500, "model_architecture": self.torch_models[0], "num_epochs": 3 } response = self.client.post(self.url, payload, format='json') print_output(f"run_train_torchtagger_without_embedding", response.data) self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST) def run_train_tagger_using_query(self): """Tests TorchTagger training, and if a new Task gets created via the signal""" payload = { "description": "TestTorchTaggerTraining", "fields": TEST_FIELD_CHOICE, "query": json.dumps(TEST_QUERY), "maximum_sample_size": 500, "model_architecture": self.torch_models[0], "num_epochs": 3, "embedding": self.test_embedding_id, } print_output(f"training tagger with payload: {payload}", 200) response = self.client.post(self.url, payload, format='json') print_output( 'test_create_binary_torchtagger_training_and_task_signal:response.data', response.data) # Check if Neurotagger gets created self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Check if f1 not NULL (train and validation success) tagger_id = response.data['id'] response = self.client.get(f'{self.url}{tagger_id}/') print_output('test_torchtagger_has_stats:response.data', response.data) for score in ['f1_score', 'precision', 'recall', 'accuracy']: self.assertTrue(isinstance(response.data[score], float)) print_output('test_torchtagger_has_classes:response.data.classes', response.data["classes"]) self.assertTrue(isinstance(response.data["classes"], list)) self.assertTrue(len(response.data["classes"]) == 2) self.test_tagger_id = tagger_id # add cleanup self.add_cleanup_files(tagger_id) def run_train_multiclass_tagger_using_fact_name(self): """Tests TorchTagger training with multiple classes and if a new Task gets created via the signal""" payload = { "description": "TestTorchTaggerTraining", "fact_name": TEST_FACT_NAME, "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 500, "model_architecture": self.torch_models[0], "num_epochs": 3, "embedding": self.test_embedding_id, } response = self.client.post(self.url, payload, format='json') print_output( 'test_create_multiclass_torchtagger_training_and_task_signal:response.data', response.data) # Check if Neurotagger gets created self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Check if f1 not NULL (train and validation success) tagger_id = response.data['id'] response = self.client.get(f'{self.url}{tagger_id}/') print_output('test_torchtagger_has_stats:response.data', response.data) for score in ['f1_score', 'precision', 'recall', 'accuracy']: self.assertTrue(isinstance(response.data[score], float)) print_output('test_torchtagger_has_classes:response.data.classes', response.data["classes"]) self.assertTrue(isinstance(response.data["classes"], list)) self.assertTrue(len(response.data["classes"]) > 2) self.test_multiclass_tagger_id = tagger_id # add cleanup self.add_cleanup_files(tagger_id) def run_train_binary_multiclass_tagger_using_fact_name(self): """Tests TorchTagger training with binary facts.""" payload = { "description": "TestBinaryMulticlassTorchTaggerTraining", "fact_name": TEST_FACT_NAME, "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 500, "model_architecture": 
self.torch_models[0], "num_epochs": 3, "embedding": self.test_embedding_id, "pos_label": TEST_POS_LABEL, "query": json.dumps(TEST_BIN_FACT_QUERY) } response = self.client.post(self.url, payload, format='json') print_output( 'test_create_binary_multiclass_torchtagger_training_and_task_signal:response.data', response.data) # Check if Neurotagger gets created self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Check if f1 not NULL (train and validation success) tagger_id = response.data['id'] response = self.client.get(f'{self.url}{tagger_id}/') print_output('test_torchtagger_has_stats:response.data', response.data) for score in ['f1_score', 'precision', 'recall', 'accuracy']: self.assertTrue(isinstance(response.data[score], float)) print_output('test_torchtagger_has_classes:response.data.classes', response.data["classes"]) self.assertTrue(isinstance(response.data["classes"], list)) self.assertTrue(len(response.data["classes"]) == 2) # add cleanup self.add_cleanup_files(tagger_id) def run_train_binary_multiclass_tagger_using_fact_name_invalid_payload( self): """Tests TorchTagger training with binary facts and invalid payload.""" # Pos label is undefined by the user invalid_payload_1 = { "description": "TestBinaryMulticlassTorchTaggerTrainingMissingPosLabel", "fact_name": TEST_FACT_NAME, "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 500, "model_architecture": self.torch_models[0], "num_epochs": 3, "embedding": self.test_embedding_id, "query": json.dumps(TEST_BIN_FACT_QUERY) } response = self.client.post(self.url, invalid_payload_1, format='json') print_output( 'test_create_binary_multiclass_torchtagger_using_fact_name_missing_pos_label:response.data', response.data) # Check if creating the Tagger fails with status code 400 self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) # The inserted pos label is not present in the data invalid_payload_2 = { "description": "TestBinaryMulticlassTorchTaggerTrainingMissingPosLabel", "fact_name": TEST_FACT_NAME, "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 500, "model_architecture": self.torch_models[0], "num_epochs": 3, "embedding": self.test_embedding_id, "query": json.dumps(TEST_BIN_FACT_QUERY), "pos_label": "invalid_fact_val" } response = self.client.post(self.url, invalid_payload_2, format='json') print_output( 'test_create_binary_multiclass_torchtagger_using_fact_name_invalid_pos_label:response.data', response.data) # Check if creating the Tagger fails with status code 400 self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) def run_train_balanced_multiclass_tagger_using_fact_name(self): """Tests TorchTagger training with multiple balanced classes and if a new Task gets created via the signal""" payload = { "description": "TestBalancedTorchTaggerTraining", "fact_name": TEST_FACT_NAME, "fields": TEST_FIELD_CHOICE, "maximum_sample_size": 150, "model_architecture": self.torch_models[0], "num_epochs": 2, "embedding": self.test_embedding_id, "balance": True, "use_sentence_shuffle": True, "balance_to_max_limit": True } response = self.client.post(self.url, payload, format='json') print_output( 'test_create_balanced_multiclass_torchtagger_training_and_task_signal:response.data', response.data) # Check if Neurotagger gets created self.assertEqual(response.status_code, status.HTTP_201_CREATED) # Check if f1 not NULL (train and validation success) tagger_id = response.data['id'] response = self.client.get(f'{self.url}{tagger_id}/') print_output('test_balanced_torchtagger_has_stats:response.data', 
response.data) for score in ['f1_score', 'precision', 'recall', 'accuracy']: self.assertTrue(isinstance(response.data[score], float)) num_examples = json.loads(response.data["num_examples"]) print_output( 'test_balanced_torchtagger_num_examples_correct:num_examples', num_examples) for class_size in num_examples.values(): self.assertTrue(class_size, payload["maximum_sample_size"]) print_output('test_balanced_torchtagger_has_classes:classes', response.data["classes"]) self.assertTrue(isinstance(response.data["classes"], list)) self.assertTrue(len(response.data["classes"]) >= 2) # add cleanup self.add_cleanup_files(tagger_id) def run_tag_text(self): """Tests tag prediction for texts.""" payload = {"text": "mine kukele, kala"} response = self.client.post( f'{self.url}{self.test_tagger_id}/tag_text/', payload) print_output('test_torchtagger_tag_text:response.data', response.data) self.assertTrue(isinstance(response.data, dict)) self.assertTrue('result' in response.data) self.assertTrue('probability' in response.data) self.assertTrue('tagger_id' in response.data) def run_tag_random_doc(self): """Tests the endpoint for the tag_random_doc action""" payload = {"indices": [{"name": self.test_index_name}]} url = f'{self.url}{self.test_tagger_id}/tag_random_doc/' response = self.client.post(url, format="json", data=payload) print_output('test_tag_random_doc:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) # Check if response is list self.assertTrue(isinstance(response.data, dict)) self.assertTrue('prediction' in response.data) self.assertTrue('result' in response.data['prediction']) self.assertTrue('probability' in response.data['prediction']) self.assertTrue('tagger_id' in response.data['prediction']) def run_epoch_reports_get(self): """Tests endpoint for retrieving epoch reports via GET""" url = f'{self.url}{self.test_tagger_id}/epoch_reports/' response = self.client.get(url, format="json") print_output('test_torchagger_epoch_reports_get:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) # Check if response is a list self.assertTrue(isinstance(response.data, list)) # Check if first report is not empty self.assertTrue(len(response.data[0]) > 0) def run_epoch_reports_post(self): """Tests endpoint for retrieving epoch reports via GET""" url = f'{self.url}{self.test_tagger_id}/epoch_reports/' payload_1 = {} payload_2 = { "ignore_fields": ["true_positive_rate", "false_positive_rate", "recall"] } response = self.client.post(url, format="json", data=payload_1) print_output( 'test_torchagger_epoch_reports_post_ignore_default:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) # Check if response is a list self.assertTrue(isinstance(response.data, list)) # Check if first report contains recall self.assertTrue("recall" in response.data[0]) response = self.client.post(url, format="json", data=payload_2) print_output( 'test_torchtagger_epoch_reports_post_ignore_custom:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) # Check if response is a list self.assertTrue(isinstance(response.data, list)) # Check if first report does NOT contains recall self.assertTrue("recall" not in response.data[0]) def run_model_export_import(self): """Tests endpoint for model export and import""" test_tagger_group_id = self.test_tagger_id # retrieve model zip url = f'{self.url}{test_tagger_group_id}/export_model/' response = self.client.get(url) # Post model zip import_url = 
f'{self.url}import_model/' response = self.client.post(import_url, data={'file': BytesIO(response.content)}) tagger_id = response.data['id'] print_output('test_import_model:response.data', import_url) self.assertEqual(response.status_code, status.HTTP_201_CREATED) torchtagger = TorchTagger.objects.get(pk=response.data["id"]) # Check if the models and plot files exist. resources = torchtagger.get_resource_paths() for path in resources.values(): file = pathlib.Path(path) self.assertTrue(file.exists()) # Tests the endpoint for the tag_random_doc action""" url = f'{self.url}{torchtagger.pk}/tag_random_doc/' payload = {"indices": [{"name": self.test_index_name}]} response = self.client.post(url, format='json', data=payload) print_output( 'test_torchtagger_tag_random_doc_after_import:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(isinstance(response.data, dict)) self.assertTrue('prediction' in response.data) self.assertTrue('result' in response.data['prediction']) self.assertTrue('probability' in response.data['prediction']) self.assertTrue('tagger_id' in response.data['prediction']) self.add_cleanup_files(tagger_id) def run_apply_binary_tagger_to_index(self): """Tests applying binary torch tagger to index using apply_to_index endpoint.""" # Make sure reindexer task has finished while self.reindexer_object.task.status != Task.STATUS_COMPLETED: print_output( 'test_apply_binary_torch_tagger_to_index: waiting for reindexer task to finish, current status:', self.reindexer_object.task.status) sleep(2) url = f'{self.url}{self.test_tagger_id}/apply_to_index/' payload = { "description": "apply torch tagger to index test task", "new_fact_name": self.new_fact_name, "new_fact_value": self.new_fact_value, "indices": [{ "name": self.test_index_copy }], "fields": TEST_FIELD_CHOICE } response = self.client.post(url, payload, format='json') print_output('test_apply_binary_torch_tagger_to_index:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_201_CREATED) tagger_object = TorchTagger.objects.get(pk=self.test_tagger_id) # Wait til the task has finished while tagger_object.task.status != Task.STATUS_COMPLETED: print_output( 'test_apply_binary_torch_tagger_to_index: waiting for applying tagger task to finish, current status:', tagger_object.task.status) sleep(2) results = ElasticAggregator( indices=[self.test_index_copy]).get_fact_values_distribution( self.new_fact_name) print_output( "test_apply_binary_torch_tagger_to_index:elastic aggerator results:", results) # Check if expected number of facts is added self.assertTrue(results[self.new_fact_value] > 10) def run_apply_tagger_to_index_invalid_input(self): """Tests applying multiclass torch tagger to index using apply_to_index endpoint.""" url = f'{self.url}{self.test_tagger_id}/apply_to_index/' payload = { "description": "apply torch tagger to index test task", "new_fact_name": self.new_fact_name, "new_fact_value": self.new_fact_value, "fields": "invalid_field_format" } response = self.client.post(url, payload, format='json') print_output('test_invalid_apply_torch_tagger_to_index:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.add_cleanup_files(self.test_tagger_id) def run_tag_and_feedback_and_retrain(self): """Tests feeback extra action.""" tagger_id = self.test_tagger_id tagger_orm: TorchTagger = TorchTagger.objects.get( pk=self.test_tagger_id) model_path = pathlib.Path(tagger_orm.model.path) print_output( 
'run_tag_and_feedback_and_retrain:assert that previous model exists', data=model_path.exists()) self.assertTrue(model_path.exists()) payload = { "text": "This is some test text for the Tagger Test", "feedback_enabled": True } tag_text_url = f'{self.url}{tagger_id}/tag_text/' response = self.client.post(tag_text_url, payload) print_output('test_tag_text_with_feedback:response.data', response.data) self.assertTrue('feedback' in response.data) # generate feedback fb_id = response.data['feedback']['id'] feedback_url = f'{self.url}{tagger_id}/feedback/' payload = {"feedback_id": fb_id, "correct_result": "FUBAR"} response = self.client.post(feedback_url, payload, format='json') print_output('test_tag_text_with_feedback:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data) self.assertTrue('success' in response.data) # sleep for a sec to allow elastic to finish its business sleep(1) # list feedback feedback_list_url = f'{self.url}{tagger_id}/feedback/' response = self.client.get(feedback_list_url) print_output('test_tag_text_list_feedback:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data) self.assertTrue(len(response.data) > 0) # add model files before retraining self.add_cleanup_files(tagger_id) # retrain model url = f'{self.url}{tagger_id}/retrain_tagger/' response = self.client.post(url) print_output('test_feedback:retrain', response.data) # test tagging again for this model payload = {"text": "This is some test text for the Tagger Test"} tag_text_url = f'{self.url}{tagger_id}/tag_text/' response = self.client.post(tag_text_url, payload) print_output('test_feedback_retrained_tag_doc:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue('result' in response.data) self.assertTrue('probability' in response.data) # Ensure that previous tagger is deleted properly. print_output( 'test_model_retrain:assert that previous model doesn\'t exist', data=model_path.exists()) self.assertFalse(model_path.exists()) # Ensure that the freshly created model wasn't deleted. tagger_orm.refresh_from_db() self.assertNotEqual(tagger_orm.model.path, str(model_path)) # delete feedback feedback_delete_url = f'{self.url}{tagger_id}/feedback/' response = self.client.delete(feedback_delete_url) print_output('test_tag_doc_delete_feedback:response.data', response.data) # sleep for a sec to allow elastic to finish its business sleep(1) # list feedback again to make sure it's empty feedback_list_url = f'{self.url}{tagger_id}/feedback/' response = self.client.get(feedback_list_url) print_output('test_tag_doc_list_feedback_after_delete:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(len(response.data) == 0) # remove created index feedback_index_url = f'{self.project_url}/feedback/' response = self.client.delete(feedback_index_url) print_output('test_delete_feedback_index:response.data', response.data) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue('success' in response.data) # add model files after retraining self.add_cleanup_files(tagger_id)
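The apply_to_index checks in this class wait for Celery tasks with an unbounded `while ... sleep(2)` loop, so a stuck task hangs the whole test run. A sketch of a bounded wait, assuming the same Task.STATUS_COMPLETED constant and a Django model instance for the task; the helper name, timeout and interval are hypothetical:

from time import sleep


# Sketch only: poll a task until it completes or a deadline passes, instead of
# looping forever on its status.
def wait_for_task_completion(task, timeout_seconds: int = 300, interval: int = 2):
    waited = 0
    while task.status != Task.STATUS_COMPLETED:
        if waited >= timeout_seconds:
            raise TimeoutError(f"Task did not finish within {timeout_seconds} seconds")
        sleep(interval)
        waited += interval
        task.refresh_from_db()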
def tearDown(self) -> None:
    CRFExtractor.objects.all().delete()
    ec = ElasticCore()
    ec.delete_index(self.test_index_copy)