Example #1
    def test_project_creation_using_locked_indices(self):
        """
        Regression test for a past bug: if any of the indices in Elasticsearch
        were locked for whatever reason, project creation failed, because its
        lazy index creation ignored existing indices and thus triggered
        FORBIDDEN_ACCESS errors in Elasticsearch.
        """
        index_names = ["locked_index_{}".format(i) for i in range(1, 5)]
        for index in index_names:
            self.__create_locked_index(index)

        response = self.client.post(
            reverse(f"{VERSION_NAMESPACE}:project-list"),
            format="json",
            data={
                "title": "faulty_project",
                "indices_write": index_names,
                "users_write": [self.admin.username]
            })

        ec = ElasticCore()
        for index in index_names:
            ec.delete_index(index)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
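
The __create_locked_index helper used above is not shown in this example. Below is a minimal sketch of what it plausibly does, assuming ElasticCore exposes the elasticsearch-py client as ec.es and that locking is done via the read_only_allow_delete block setting; both details are assumptions, not the original helper:

    def __create_locked_index(self, index_name: str):
        # Hypothetical reconstruction of the helper used in the test above.
        ec = ElasticCore()
        # Create the index; ignore 400 in case it already exists.
        ec.es.indices.create(index=index_name, ignore=[400])
        # Lock the index: subsequent write attempts fail with a cluster-block
        # (FORBIDDEN) error, reproducing the conditions of the original bug.
        ec.es.indices.put_settings(
            index=index_name,
            body={"index.blocks.read_only_allow_delete": True})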
Example #2
 def tearDown(self) -> None:
     Tagger.objects.all().delete()
     ec = ElasticCore()
     res = ec.delete_index(self.test_index_copy)
     ec.delete_index(index=self.test_index_name, ignore=[400, 404])
     print_output(f"Delete apply_taggers test index {self.test_index_copy}",
                  res)
Example #3
 def tearDown(self) -> None:
     ec = ElasticCore()
     ec.delete_index(index=self.test_index_name, ignore=[400, 404])
     print_output(f"Delete [Rakun Extractor] test index {self.test_index_name}", None)
     Embedding.objects.all().delete()
     print_output("Delete Rakun FASTTEXT Embeddings", None)
Example #4
 def destroy(self, request, pk=None, **kwargs):
     with transaction.atomic():
         index_name = Index.objects.get(pk=pk).name
         es = ElasticCore()
         es.delete_index(index_name)
         Index.objects.filter(pk=pk).delete()
         return Response(
             {"message": f"Deleted index {index_name} from Elasticsearch!"})
Example #5
 def bulk_delete(self, request, project_pk=None):
     serializer: IndexBulkDeleteSerializer = self.get_serializer(
         data=request.data)
     serializer.is_valid(raise_exception=True)
     # Initialize Elastic requirements.
     ec = ElasticCore()
     # Get the index names.
     ids = serializer.validated_data["ids"]
     objects = Index.objects.filter(pk__in=ids)
     index_names = [item.name for item in objects]
     # Ensure deletion on both Elastic and DB.
     if index_names:
         ec.delete_index(",".join(index_names))
     deleted = objects.delete()
     info = {"num_deleted": deleted[0], "deleted_types": deleted[1]}
     return Response(info, status=status.HTTP_200_OK)
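
The IndexBulkDeleteSerializer itself is not shown here; since only validated_data["ids"] is read, a minimal sketch of what it plausibly looks like (an assumption, not the project's actual serializer):

 from rest_framework import serializers

 class IndexBulkDeleteSerializer(serializers.Serializer):
     # Primary keys of the Index rows to remove from Elasticsearch and the DB.
     ids = serializers.ListField(child=serializers.IntegerField(min_value=0))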
Example #6
class IndexViewsTest(APITestCase):
    @classmethod
    def setUpTestData(cls):
        # Owner of the project
        cls.user = create_test_user('user',
                                    '*****@*****.**',
                                    'pw',
                                    superuser=True)

    def setUp(self) -> None:
        self.client.login(username="******", password="******")
        self.ec = ElasticCore()
        self.ids = []
        self.index_names = [
            "test_for_index_endpoint_1", "test_for_index_endpoint_2"
        ]

        for index_name in self.index_names:
            index, is_created = Index.objects.get_or_create(name=index_name)
            self.ec.es.indices.create(index=index_name, ignore=[400, 404])
            self.ids.append(index.pk)

    def test_bulk_delete(self):
        url = reverse("v2:index-bulk-delete")
        response = self.client.post(url, data={"ids": self.ids}, format="json")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        for index_name in self.index_names:
            self.assertFalse(self.ec.es.indices.exists(index_name))
        print_output("test_bulk_delete:response.data", response.data)

    def tearDown(self) -> None:
        indices = Index.objects.filter(pk__in=self.ids)
        names = [index.name for index in indices]
        if names:
            indices.delete()
            for index in names:
                self.ec.delete_index(index=index, ignore=[400, 404])
Example #7
class BertTaggerObjectViewTests(APITransactionTestCase):
    def setUp(self):
        # Owner of the project
        self.test_index_name = reindex_test_dataset()
        self.user = create_test_user('BertTaggerOwner', '*****@*****.**', 'pw')
        self.admin_user = create_test_user("AdminBertUser",
                                           '*****@*****.**',
                                           'pw',
                                           superuser=True)
        self.project = project_creation("BertTaggerTestProject",
                                        self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.project.users.add(self.admin_user)
        self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/bert_taggers/'
        self.project_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}'

        self.test_tagger_id = None
        self.test_multiclass_tagger_id = None

        self.client.login(username='******', password='******')

        # Check if TEST_BERT_MODEL is already downloaded
        available_models = get_downloaded_bert_models(
            BERT_PRETRAINED_MODEL_DIRECTORY)
        self.test_model_existed = TEST_BERT_MODEL in available_models
        download_bert_requirements(BERT_PRETRAINED_MODEL_DIRECTORY,
                                   [TEST_BERT_MODEL],
                                   cache_directory=BERT_CACHE_DIR,
                                   num_labels=2)

        # new fact name and value used when applying tagger to index
        self.new_fact_name = "TEST_BERT_TAGGER_NAME"
        self.new_multiclass_fact_name = "TEST_BERT_TAGGER_NAME_MC"
        self.new_fact_value = "TEST_BERT_TAGGER_VALUE"

        # Create copy of test index
        self.reindex_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
        # Generate name for new index containing random id to make sure it doesn't already exist
        self.test_index_copy = f"test_apply_bert_tagger_{uuid.uuid4().hex}"

        self.reindex_payload = {
            "description": "test index for applying taggers",
            "indices": [self.test_index_name],
            "query": json.dumps(TEST_QUERY),
            "new_index": self.test_index_copy,
            "fields": TEST_FIELD_CHOICE
        }
        resp = self.client.post(self.reindex_url,
                                self.reindex_payload,
                                format='json')
        print_output(
            "reindex test index for applying bert tagger:response.data:",
            resp.json())
        self.reindexer_object = Reindexer.objects.get(pk=resp.json()["id"])

        self.test_imported_binary_gpu_tagger_id = self.import_test_model(
            TEST_BERT_TAGGER_BINARY_GPU)
        self.test_imported_multiclass_gpu_tagger_id = self.import_test_model(
            TEST_BERT_TAGGER_MULTICLASS_GPU)

        self.test_imported_binary_cpu_tagger_id = self.import_test_model(
            TEST_BERT_TAGGER_BINARY_CPU)
        self.ec = ElasticCore()

    def import_test_model(self, file_path: str):
        """Import fine-tuned models for testing."""
        print_output("Importing model from file:", file_path)
        files = {"file": open(file_path, "rb")}
        import_url = f'{self.url}import_model/'
        resp = self.client.post(import_url,
                                data={
                                    'file': open(file_path, "rb")
                                }).json()
        print_output("Importing test model:", resp)
        return resp["id"]

    def test(self):
        self.run_train_multiclass_bert_tagger_using_fact_name()
        self.run_train_balanced_multiclass_bert_tagger_using_fact_name()
        self.run_train_bert_tagger_from_checkpoint_model_bin2bin()
        self.run_train_bert_tagger_from_checkpoint_model_bin2mc()
        self.run_train_binary_multiclass_bert_tagger_using_fact_name()
        self.run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_payload(
        )
        self.run_train_bert_tagger_using_query()
        self.run_bert_tag_text()
        self.run_bert_tag_with_imported_gpu_model()
        self.run_bert_tag_with_imported_cpu_model()
        self.run_bert_tag_random_doc()
        self.run_bert_epoch_reports_get()
        self.run_bert_epoch_reports_post()
        self.run_bert_get_available_models()
        self.run_bert_download_pretrained_model()
        self.run_bert_tag_and_feedback_and_retrain()
        self.run_bert_model_export_import()
        self.run_apply_binary_tagger_to_index()
        self.run_apply_multiclass_tagger_to_index()
        self.run_apply_tagger_to_index_invalid_input()
        self.run_bert_tag_text_persistent()

        self.run_test_that_user_cant_delete_pretrained_model()
        self.run_test_that_admin_users_can_delete_pretrained_model()

        self.add_cleanup_files(self.test_tagger_id)
        self.add_cleanup_folders()

    def tearDown(self) -> None:
        res = self.ec.delete_index(self.test_index_copy)
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
        print_output(
            f"Delete apply_bert_taggers test index {self.test_index_copy}",
            res)

    def add_cleanup_files(self, tagger_id: int):
        tagger_object = BertTaggerObject.objects.get(pk=tagger_id)
        self.addCleanup(remove_file, tagger_object.model.path)
        if not TEST_KEEP_PLOT_FILES:
            self.addCleanup(remove_file, tagger_object.plot.path)

    def add_cleanup_folders(self):
        if not self.test_model_existed:
            test_model_dir = os.path.join(
                BERT_PRETRAINED_MODEL_DIRECTORY,
                BertTagger.normalize_name(TEST_BERT_MODEL))
            self.addCleanup(remove_folder, test_model_dir)

    def run_train_multiclass_bert_tagger_using_fact_name(self):
        """Tests BertTagger training with multiple classes and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBertTaggerObjectTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "indices": [{
                "name": self.test_index_name
            }],
            "maximum_sample_size": 500,
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, payload, format='json')

        print_output(
            'test_create_multiclass_bert_tagger_training_and_task_signal:response.data',
            response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        self.test_multiclass_tagger_id = tagger_id
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_multiclass_bert_tagger_has_stats:response.data',
                     response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_multiclass_bert_tagger_has_classes:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) > 2)

        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_bert_tagger_using_fact_name(self):
        """Tests BertTagger training with binary facts."""
        payload = {
            "description": "Test Bert Tagger training binary multiclass",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "pos_label": TEST_POS_LABEL,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{
                "name": self.test_index_name
            }],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, payload, format='json')

        print_output(
            'test_run_train_binary_multiclass_bert_tagger_using_fact_name:response.data',
            response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output(
            'test_binary_multiclass_bert_tagger_has_stats:response.data',
            response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_binary_multiclass_bert_tagger_has_classes:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)

        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_payload(
            self):
        """Tests BertTagger training with binary facts."""
        # Pos label is undefined by the user
        invalid_payload_1 = {
            "description":
            "Test Bert Tagger training binary multiclass invalid",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{
                "name": self.test_index_name
            }],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, invalid_payload_1, format='json')

        print_output(
            'test_run_train_binary_multiclass_bert_tagger_using_fact_name_missing_pos_label:response.data',
            response.data)
        # Check if creating the BertTagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        # The pos label the user has inserted is not present in the data
        invalid_payload_2 = {
            "description":
            "Test Bert Tagger training binary multiclass invalid",
            "fact_name": TEST_FACT_NAME,
            "pos_label": "invalid_fact_val",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "indices": [{
                "name": self.test_index_name
            }],
            "maximum_sample_size": 50,
            "num_epochs": 1,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }
        response = self.client.post(self.url, invalid_payload_2, format='json')

        print_output(
            'test_run_train_binary_multiclass_bert_tagger_using_fact_name_invalid_pos_label:response.data',
            response.data)
        # Check if creating the BertTagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_train_balanced_multiclass_bert_tagger_using_fact_name(self):
        """Tests balanced BertTagger training with multiple classes and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBalancedBertTaggerObjectTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "indices": [{
                "name": self.test_index_name
            }],
            "maximum_sample_size": 500,
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL,
            "balance": True,
            "use_sentence_shuffle": True,
            "balance_to_max_limit": True
        }
        response = self.client.post(self.url, payload, format='json')

        print_output(
            'test_create_balanced_multiclass_bert_tagger_training_and_task_signal:response.data',
            response.data)
        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        self.test_multiclass_tagger_id = tagger_id
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output(
            'test_balanced_multiclass_bert_tagger_has_stats:response.data',
            response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_balanced_multiclass_bert_tagger_has_classes:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)

        num_examples = json.loads(response.data["num_examples"])
        print_output(
            'test_balanced_bert_tagger_num_examples_correct:num_examples',
            num_examples)
        for class_size in num_examples.values():
            self.assertEqual(class_size, payload["maximum_sample_size"])

        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_using_query(self):
        """Tests BertTagger training, and if a new Task gets created via the signal."""
        payload = {
            "description": "TestBertTaggerTraining",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_QUERY),
            "maximum_sample_size": 500,
            "indices": [{
                "name": self.test_index_name
            }],
            "num_epochs": 2,
            "max_length": 15,
            "bert_model": TEST_BERT_MODEL
        }

        print_output(f"training tagger with payload: {payload}", 200)
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_create_binary_bert_tagger_training_and_task_signal:response.data',
            response.data)

        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_binary_bert_tagger_has_stats:response.data',
                     response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_binary_bert_tagger_has_classes:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)

        # set trained tagger as active tagger
        self.test_tagger_id = tagger_id
        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_from_checkpoint_model_bin2bin(self):
        """Tests training BertTagger from a checkpoint."""
        payload = {
            "description": "Test training binary BertTagger from checkpoint",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_QUERY),
            "maximum_sample_size": 500,
            "indices": [{
                "name": self.test_index_name
            }],
            "num_epochs": 2,
            "max_length": 12,
            "bert_model": TEST_BERT_MODEL,
            "checkpoint_model": self.test_tagger_id
        }

        print_output(
            "training binary bert tagger from checkpoint with payload:",
            payload)
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_train_bert_tagger_from_checkpoint_model_bin2bin:POST:response.data',
            response.data)

        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output(
            'test_train_bert_tagger_from_checkpoint_model_bin2bin.has_stats:response.data',
            response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_train_bert_tagger_from_checkpoint_model.has_stats:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)

        self.add_cleanup_files(tagger_id)

    def run_train_bert_tagger_from_checkpoint_model_bin2mc(self):
        """Tests training BertTagger from a checkpoint."""
        payload = {
            "description":
            "Test training multiclass BertTagger from checkpoint",
            "fields": TEST_FIELD_CHOICE,
            "fact_name": TEST_FACT_NAME,
            "maximum_sample_size": 500,
            "indices": [{
                "name": self.test_index_name
            }],
            "num_epochs": 2,
            "max_length": 12,
            "bert_model": TEST_BERT_MODEL,
            "checkpoint_model": self.test_tagger_id
        }

        print_output(
            "training multiclass bert tagger from checkpoint with payload:",
            payload)
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_train_bert_tagger_from_checkpoint_model_bin2mc:POST:response.data',
            response.data)

        # Check if BertTagger gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Give the tagger some time to finish training
        sleep(5)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output(
            'test_train_bert_tagger_from_checkpoint_model_bin2mc.has_stats:response.data',
            response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output(
            'test_train_bert_tagger_from_checkpoint_model_bin2mc.has_classes:response.data.classes',
            response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)

        self.add_cleanup_files(tagger_id)

    def run_bert_tag_with_imported_gpu_model(self):
        """Test applying imported model trained on GPU."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(
            f'{self.url}{self.test_imported_binary_gpu_tagger_id}/tag_text/',
            payload)
        print_output(
            'test_bert_tagger_tag_with_imported_gpu_model:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)
        # Check if tagger learned to predict
        self.assertEqual("true", response.data["result"])

        self.add_cleanup_files(self.test_imported_binary_gpu_tagger_id)

    def run_bert_tag_with_imported_cpu_model(self):
        """Tests applying imported model trained on CPU."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(
            f'{self.url}{self.test_imported_binary_cpu_tagger_id}/tag_text/',
            payload)
        print_output(
            'test_bert_tagger_tag_with_imported_cpu_model:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)
        # Check if tagger learned to predict
        self.assertEqual("true", response.data["result"])

        self.add_cleanup_files(self.test_imported_binary_cpu_tagger_id)

    def run_bert_tag_text(self):
        """Tests tag prediction for texts."""
        payload = {"text": "mine kukele, loll"}
        response = self.client.post(
            f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        print_output('test_bert_tagger_tag_text:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue("probability" in response.data)
        self.assertTrue("result" in response.data)
        self.assertTrue("tagger_id" in response.data)

    def run_bert_tag_random_doc(self):
        """Tests the endpoint for the tag_random_doc action"""
        # Tag with specified fields
        payload = {
            "indices": [{
                "name": self.test_index_name
            }],
            "fields": TEST_FIELD_CHOICE
        }
        url = f'{self.url}{self.test_tagger_id}/tag_random_doc/'
        response = self.client.post(url, format="json", data=payload)
        print_output('test_bert_tagger_tag_random_doc:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a dict
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue("prediction" in response.data)
        self.assertTrue("document" in response.data)
        self.assertTrue("probability" in response.data["prediction"])
        self.assertTrue("result" in response.data["prediction"])
        self.assertTrue("tagger_id" in response.data["prediction"])

        # Tag with unspecified fields
        payload = {"indices": [{"name": self.test_index_name}]}
        url = f'{self.url}{self.test_tagger_id}/tag_random_doc/'
        response = self.client.post(url, format="json", data=payload)
        print_output('test_bert_tagger_tag_random_doc:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a dict
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue("prediction" in response.data)
        self.assertTrue("document" in response.data)
        self.assertTrue("probability" in response.data["prediction"])
        self.assertTrue("result" in response.data["prediction"])
        self.assertTrue("tagger_id" in response.data["prediction"])

    def run_bert_epoch_reports_get(self):
        """Tests endpoint for retrieving epoch reports via GET"""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        response = self.client.get(url, format="json")
        print_output('test_bert_tagger_epoch_reports_get:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report is not empty
        self.assertTrue(len(response.data[0]) > 0)

    def run_bert_epoch_reports_post(self):
        """Tests endpoint for retrieving epoch reports via GET"""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        payload_1 = {}
        payload_2 = {
            "ignore_fields":
            ["true_positive_rate", "false_positive_rate", "recall"]
        }

        response = self.client.post(url, format="json", data=payload_1)
        print_output(
            'test_bert_tagger_epoch_reports_post_ignore_default:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report contains recall
        self.assertTrue("recall" in response.data[0])

        response = self.client.post(url, format="json", data=payload_2)
        print_output(
            'test_bert_tagger_epoch_reports_post_ignore_custom:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report does NOT contain recall
        self.assertTrue("recall" not in response.data[0])

    def run_bert_get_available_models(self):
        """Test endpoint for retrieving available BERT models."""
        url = f'{self.url}available_models/'
        response = self.client.get(url, format="json")
        print_output('test_bert_tagger_get_available_models:response.data',
                     response.data)
        available_models = get_downloaded_bert_models(
            BERT_PRETRAINED_MODEL_DIRECTORY)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if the endpoint returns currently available models
        self.assertCountEqual(response.data, available_models)

    def run_bert_download_pretrained_model(self):
        """Test endpoint for downloading pretrained BERT model."""
        self.client.login(username="******", password='******')
        url = f'{self.url}download_pretrained_model/'
        # Test endpoint with valid payload
        valid_payload = {"bert_model": "prajjwal1/bert-tiny"}
        response = self.client.post(url, format="json", data=valid_payload)
        print_output(
            'test_bert_tagger_download_pretrained_model_valid_input:response.data',
            response.data)
        if ALLOW_BERT_MODEL_DOWNLOADS:
            self.assertEqual(response.status_code, status.HTTP_200_OK)
        else:
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        # Test endpoint with invalid payload
        invalid_payload = {"bert_model": "foo"}
        response = self.client.post(url, format="json", data=invalid_payload)
        print_output(
            'test_bert_tagger_download_pretrained_model_invalid_input:response.data',
            response.data)

        # The endpoint should return an error
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_bert_model_export_import(self):
        """Tests endpoint for model export and import"""

        # retrieve model zip
        url = f'{self.url}{self.test_tagger_id}/export_model/'
        response = self.client.get(url)

        # Post model zip
        import_url = f'{self.url}import_model/'
        response = self.client.post(import_url,
                                    data={'file': BytesIO(response.content)})
        print_output('test_bert_import_model:response.data', response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        bert_tagger = BertTaggerObject.objects.get(pk=response.data["id"])
        tagger_id = response.data['id']

        # Check if the models and plot files exist.
        resources = bert_tagger.get_resource_paths()
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertTrue(file.exists())

        # Tests the endpoint for the tag_random_doc action"""
        url = f'{self.url}{bert_tagger.pk}/tag_random_doc/'
        random_doc_payload = {
            "indices": [{
                "name": self.test_index_name
            }],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url,
                                    data=random_doc_payload,
                                    format="json")
        print_output(
            'test_bert_tag_random_doc_after_model_import:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        self.assertTrue('prediction' in response.data)
        # remove exported tagger files
        self.add_cleanup_files(tagger_id)

    def run_apply_binary_tagger_to_index(self):
        """Tests applying binary BERT tagger to index using apply_to_index endpoint."""
        # Make sure reindexer task has finished
        while self.reindexer_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_binary_bert_tagger_to_index: waiting for reindexer task to finish, current status:',
                self.reindexer_object.task.status)
            sleep(2)
            # Refresh the cached Task row so the loop can see status updates.
            self.reindexer_object.task.refresh_from_db()

        url = f'{self.url}{self.test_imported_binary_gpu_tagger_id}/apply_to_index/'

        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "indices": [{
                "name": self.test_index_copy
            }],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url, payload, format='json')
        print_output('test_apply_binary_bert_tagger_to_index:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        tagger_object = BertTaggerObject.objects.get(
            pk=self.test_imported_binary_gpu_tagger_id)

        # Wait until the task has finished
        while tagger_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_binary_bert_tagger_to_index: waiting for applying tagger task to finish, current status:',
                tagger_object.task.status)
            sleep(2)
            # Refresh the cached Task row so the loop can see status updates.
            tagger_object.task.refresh_from_db()

        results = ElasticAggregator(
            indices=[self.test_index_copy]).get_fact_values_distribution(
                self.new_fact_name)
        print_output(
            "test_apply_binary_bert_tagger_to_index:elastic aggerator results:",
            results)

        # Check if the expected number of facts are added to the index
        expected_number_of_facts = 29
        self.assertEqual(results[self.new_fact_value],
                         expected_number_of_facts)

        self.add_cleanup_files(self.test_imported_binary_gpu_tagger_id)

    def run_apply_multiclass_tagger_to_index(self):
        """Tests applying multiclass BERT tagger to index using apply_to_index endpoint."""
        # Make sure reindexer task has finished
        while self.reindexer_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_multiclass_bert_tagger_to_index: waiting for reindexer task to finish, current status:',
                self.reindexer_object.task.status)
            sleep(2)
            # Refresh the cached Task row so the loop can see status updates.
            self.reindexer_object.task.refresh_from_db()

        url = f'{self.url}{self.test_imported_multiclass_gpu_tagger_id}/apply_to_index/'

        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_multiclass_fact_name,
            "new_fact_value": self.new_fact_value,
            "indices": [{
                "name": self.test_index_copy
            }],
            "fields": TEST_FIELD_CHOICE
        }
        response = self.client.post(url, payload, format='json')
        print_output(
            'test_apply_multiclass_bert_tagger_to_index:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        tagger_object = BertTaggerObject.objects.get(
            pk=self.test_imported_multiclass_gpu_tagger_id)

        # Wait until the task has finished
        while tagger_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_multiclass_bert_tagger_to_index: waiting for applying tagger task to finish, current status:',
                tagger_object.task.status)
            sleep(2)
            # Refresh the cached Task row so the loop can see status updates.
            tagger_object.task.refresh_from_db()

        results = ElasticAggregator(
            indices=[self.test_index_copy]).get_fact_values_distribution(
                self.new_multiclass_fact_name)
        print_output(
            "test_apply_multiclass_bert_tagger_to_index:elastic aggerator results:",
            results)

        # Check if the expected facts and the expected number of them are added to the index
        expected_fact_value = "bar"
        expected_number_of_facts = 30
        self.assertIn(expected_fact_value, results)
        self.assertEqual(results[expected_fact_value],
                         expected_number_of_facts)

        self.add_cleanup_files(self.test_imported_multiclass_gpu_tagger_id)

    def run_apply_tagger_to_index_invalid_input(self):
        """Tests applying multiclass BERT tagger to index using apply_to_index endpoint."""

        url = f'{self.url}{self.test_tagger_id}/apply_to_index/'

        payload = {
            "description": "apply bert tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "fields": "invalid_field_format",
            "bulk_size": 100,
            "query": json.dumps(TEST_QUERY)
        }
        response = self.client.post(url, payload, format='json')
        print_output('test_invalid_apply_bert_tagger_to_index:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        self.add_cleanup_files(self.test_tagger_id)

    def run_bert_tag_and_feedback_and_retrain(self):
        """Tests feeback extra action."""

        # Get basic information to check for previous tagger deletion after retraining has finished.
        tagger_orm: BertTaggerObject = BertTaggerObject.objects.get(
            pk=self.test_tagger_id)
        model_path = pathlib.Path(tagger_orm.model.path)
        print_output(
            'run_bert_tag_and_feedback_and_retrain:assert that previous model exists',
            data=model_path.exists())
        self.assertTrue(model_path.exists())

        payload = {
            "text": "This is some test text for the Tagger Test",
            "feedback_enabled": True
        }
        tag_text_url = f'{self.url}{self.test_tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output('test_bert_tag_text_with_feedback:response.data',
                     response.data)
        self.assertTrue('feedback' in response.data)

        # generate feedback
        fb_id = response.data['feedback']['id']
        feedback_url = f'{self.url}{self.test_tagger_id}/feedback/'
        payload = {"feedback_id": fb_id, "correct_result": "FUBAR"}
        response = self.client.post(feedback_url, payload, format='json')
        print_output('test_bert_tag_text_feedback:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue('success' in response.data)
        # sleep for a sec to allow elastic to finish its business
        sleep(1)
        # list feedback
        feedback_list_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output('test_tag_text_list_feedback:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue(len(response.data) > 0)

        # retrain model
        url = f'{self.url}{self.test_tagger_id}/retrain_tagger/'
        response = self.client.post(url)
        print_output('test_bert_tagger_feedback:retrain', response.data)
        # test tagging again for this model
        payload = {"text": "This is some test text for the Tagger Test"}
        tag_text_url = f'{self.url}{self.test_tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output(
            'test_bert_tagger_feedback_retrained_tag_doc:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('result' in response.data)
        self.assertTrue('probability' in response.data)

        # Ensure that previous tagger is deleted properly.
        print_output(
            'test_model_retrain:assert that previous model doesnt exist',
            data=model_path.exists())
        self.assertFalse(model_path.exists())
        # Ensure that the freshly created model wasn't deleted.
        tagger_orm.refresh_from_db()
        self.assertNotEqual(tagger_orm.model.path, str(model_path))

        # delete feedback
        feedback_delete_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.delete(feedback_delete_url)
        print_output('test_bert_tagger_tag_doc_delete_feedback:response.data',
                     response.data)
        # sleep for a sec to allow elastic to finish its business
        sleep(1)
        # list feedback again to make sure it's empty
        feedback_list_url = f'{self.url}{self.test_tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output(
            'test_bert_tagger_tag_doc_list_feedback_after_delete:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(len(response.data) == 0)
        # remove created index
        feedback_index_url = f'{self.project_url}/feedback/'
        response = self.client.delete(feedback_index_url)
        print_output('test_bert_tagger_delete_feedback_index:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('success' in response.data)

        self.add_cleanup_files(self.test_tagger_id)

    def run_bert_tag_text_persistent(self):
        """Tests tag prediction for texts using persistent models."""
        payload = {"text": "mine kukele, loll", "persistent": True}
        # First try
        start_1 = time_ns()
        response = self.client.post(
            f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        end_1 = time_ns() - start_1

        # Give time for persistent taggers to load
        sleep(1)

        # Second try
        start_2 = time_ns()
        response = self.client.post(
            f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        end_2 = time_ns() - start_2
        # Test if second attempt faster
        print_output('test_bert_tagger_persistent speed:', (end_1, end_2))
        assert end_2 < end_1

    def run_test_that_user_cant_delete_pretrained_model(self):
        self.client.login(username='******', password='******')

        url = reverse("v2:bert_tagger-delete-pretrained-model",
                      kwargs={"project_pk": self.project.pk})
        resp = self.client.post(url,
                                data={"model_name": "EMBEDDIA/finest-bert"})
        print_output(
            "run_test_that_user_cant_delete_pretrained_model:response.data",
            data=resp.data)
        self.assertIn(resp.status_code,
                      (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))

    def run_test_that_admin_users_can_delete_pretrained_model(self):
        self.client.login(username="******", password='******')

        url = reverse("v2:bert_tagger-delete-pretrained-model",
                      kwargs={"project_pk": self.project.pk})
        model_name = "prajjwal1/bert-tiny"
        file_name = BertTagger.normalize_name(model_name)
        model_path = pathlib.Path(BERT_PRETRAINED_MODEL_DIRECTORY) / file_name
        self.assertTrue(model_path.exists())
        resp = self.client.post(url, data={"model_name": model_name})
        print_output(
            "run_test_that_admin_users_can_delete_pretrained_model:response.data",
            data=resp.data)
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(model_path.exists())
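
The polling loops in run_apply_binary_tagger_to_index and run_apply_multiclass_tagger_to_index rely on refresh_from_db() to observe status changes and have no upper bound on waiting. Below is a hedged helper sketch that bounds the wait; wait_for_task is not part of the original test class:

    def wait_for_task(task, timeout_s: int = 300, poll_s: int = 2):
        """Poll a Task row until completion, refreshing it from the DB each round."""
        waited = 0
        while task.status != Task.STATUS_COMPLETED:
            if waited >= timeout_s:
                raise TimeoutError(
                    f"Task {task.pk} still '{task.status}' after {timeout_s}s")
            sleep(poll_s)
            waited += poll_s
            task.refresh_from_db()  # pick up status changes written by the worker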
Example #8
class ReindexerViewTests(APITransactionTestCase):
    def setUp(self):
        """ user needs to be admin, because of changed indices permissions """
        self.test_index_name = reindex_test_dataset()
        self.default_password = '******'
        self.default_username = '******'
        self.user = create_test_user(self.default_username, '*****@*****.**',
                                     self.default_password)

        # create admin to test indices removal from project
        self.admin = create_test_user(name='admin', password='******')
        self.admin.is_superuser = True
        self.admin.save()
        self.project = project_creation("ReindexerTestProject",
                                        self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.ec = ElasticCore()
        self.client.login(username=self.default_username,
                          password=self.default_password)

        self.new_index_name = f"{TEST_FIELD}_2"

    def tearDown(self) -> None:
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])

    def test_run(self):
        existing_new_index_payload = {
            "description": "TestWrongField",
            "indices": [self.test_index_name],
            "new_index":
            REINDEXER_TEST_INDEX,  # index created for test purposes
        }
        wrong_fields_payload = {
            "description": "TestWrongField",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "fields": ['12345'],
        }
        wrong_indices_payload = {
            "description": "TestWrongIndex",
            "indices": ["Wrong_Index"],
            "new_index": TEST_INDEX_REINDEX,
        }
        pick_fields_payload = {
            "description": "TestManyReindexerFields",
            # this has a problem with possible name duplicates
            "fields":
            [TEST_FIELD, 'comment_content_clean.text', 'texta_facts'],
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
        }
        # duplicate name problem?
        # if you want to actually test it, add an index to indices and project indices
        join_indices_fields_payload = {
            "description": "TestReindexerJoinFields",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
        }
        test_query_payload = {
            "description": "TestQueryFiltering",
            "scroll_size": 100,
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "query": json.dumps(TEST_QUERY)
        }
        random_docs_payload = {
            "description": "TestReindexerRandomFields",
            "indices": [self.test_index_name],
            "new_index": TEST_INDEX_REINDEX,
            "random_size": 500,
        }

        update_field_type_payload = {
            "description":
            "TestReindexerUpdateFieldType",
            "fields": [],
            "indices": [self.test_index_name],
            "new_index":
            TEST_INDEX_REINDEX,
            "field_type": [
                {
                    "path": "comment_subject",
                    "field_type": "long",
                    "new_path_name": "CHANGED_NAME"
                },
                {
                    "path": "comment_content_lemmas",
                    "field_type": "fact",
                    "new_path_name": "CHANGED_TOO"
                },
                {
                    "path": "comment_content_clean.stats.text_length",
                    "field_type": "boolean",
                    "new_path_name": "CHANGED_AS_WELL"
                },
            ],
        }
        for REINDEXER_VALIDATION_TEST_INDEX in (
                REINDEXER_VALIDATION_TEST_INDEX_1,
                REINDEXER_VALIDATION_TEST_INDEX_2,
                REINDEXER_VALIDATION_TEST_INDEX_3,
                REINDEXER_VALIDATION_TEST_INDEX_4,
                REINDEXER_VALIDATION_TEST_INDEX_5,
                REINDEXER_VALIDATION_TEST_INDEX_6):
            new_index_validation_payload = {
                "description": "TestNewIndexValidation",
                "indices": [self.test_index_name],
                "new_index": REINDEXER_VALIDATION_TEST_INDEX
            }
            url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
            self.check_new_index_validation(url, new_index_validation_payload)

        for payload in (
                existing_new_index_payload,
                wrong_indices_payload,
                wrong_fields_payload,
                pick_fields_payload,
                join_indices_fields_payload,
                test_query_payload,
                random_docs_payload,
                update_field_type_payload,
        ):
            url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
            self.run_create_reindexer_task_signal(self.project, url, payload)

    def run_create_reindexer_task_signal(self,
                                         project,
                                         url,
                                         payload,
                                         overwrite=False):
        """ Tests the endpoint for a new Reindexer task, and if a new Task gets created via the signal
           checks if new_index was removed """
        try:
            self.ec.delete_index(TEST_INDEX_REINDEX)
        except Exception:
            print(f'{TEST_INDEX_REINDEX} was not deleted')
        response = self.client.post(url, payload, format='json')
        print_output('run_create_reindexer_task_signal:response.data',
                     response.data)
        self.check_update_forbidden(url, payload)
        self.is_new_index_created_if_yes_remove(response, payload, project)
        self.is_reindexed_index_added_to_project_if_yes_remove(
            response, payload['new_index'], project)
        assert TEST_INDEX_REINDEX not in ElasticCore().get_indices()

    def check_new_index_validation(self, url, new_index_validation_payload):
        response = self.client.post(url,
                                    new_index_validation_payload,
                                    format='json')
        print_output('new_index_validation:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.data["detail"].code, "invalid_index_name")

    def is_new_index_created_if_yes_remove(self, response, payload, project):
        """ Check if new_index gets created
            Check if new_index gets re-indexed and completed
            remove test new_index """
        if project.get_indices() is None or response.exception:
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        else:
            self.assertEqual(response.status_code, status.HTTP_201_CREATED)
            created_reindexer = Reindexer.objects.get(id=response.data['id'])
            print_output("Re-index task status: ",
                         created_reindexer.task.status)
            self.assertEqual(created_reindexer.task.status,
                             Task.STATUS_COMPLETED)
            # self.check_positive_doc_count()
            new_index = response.data['new_index']
            delete_response = self.ec.delete_index(new_index)
            print_output("Reindexer Test index remove status", delete_response)

    def is_reindexed_index_added_to_project_if_yes_remove(
            self, response, new_index, project):
        # project resource user is not supposed to have indices remove permission, so use admin
        self.client.login(username='******', password='******')
        url = f'{TEST_VERSION_PREFIX}/projects/{project.id}/'
        check = self.client.get(url, format='json')
        if response.status_code == 201:
            assert new_index in [
                index["name"] for index in check.data['indices']
            ]
            print_output('Re-indexed index added to project', check.data)
            index_pk = Index.objects.get(name=new_index).pk
            remove_index_url = reverse(
                f"{VERSION_NAMESPACE}:project-remove-indices",
                kwargs={"pk": self.project.pk})
            remove_response = self.client.post(remove_index_url,
                                               {"indices": [index_pk]},
                                               format='json')
            print_output("Re-indexed index removed from project",
                         remove_response.status_code)
            self.delete_reindexing_task(project, response)

        if response.status_code == 400:
            print_output('Re-indexed index not added to project', check.data)

        check = self.client.get(url, format='json')
        assert new_index not in [
            index["name"] for index in check.data['indices']
        ]
        # Log in with project user again
        self.client.login(username=self.default_username,
                          password=self.default_password)

    def validate_fields(self, project, payload):
        project_fields = self.ec.get_fields(project.get_indices())
        project_field_paths = [field["path"] for field in project_fields]
        for field in payload['fields']:
            if field not in project_field_paths:
                return False
        return True

    def validate_indices(self, project, payload):
        for index in payload['indices']:
            if index not in project.get_indices():
                return False
        return True

    def check_positive_doc_count(self):
        # current reindexing tests require approx 2 seconds delay
        sleep(5)
        count_new_documents = ElasticSearcher(
            indices=TEST_INDEX_REINDEX).count()
        print_output("Bulk add doc count", count_new_documents)
        assert count_new_documents > 0

    def check_update_forbidden(self, url, payload):
        put_response = self.client.put(url, payload, format='json')
        patch_response = self.client.patch(url, payload, format='json')
        print_output("put_response.data", put_response.data)
        print_output("patch_response.data", patch_response.data)
        self.assertEqual(put_response.status_code,
                         status.HTTP_405_METHOD_NOT_ALLOWED)
        self.assertEqual(patch_response.status_code,
                         status.HTTP_405_METHOD_NOT_ALLOWED)

    def delete_reindexing_task(self, project, response):
        """ test delete reindex task """
        task_url = response.data['url']
        get_response = self.client.get(task_url)
        self.assertEqual(get_response.status_code, status.HTTP_200_OK)
        delete_response = self.client.delete(task_url, format='json')
        self.assertEqual(delete_response.status_code,
                         status.HTTP_204_NO_CONTENT)
        get_response = self.client.get(task_url)
        self.assertEqual(get_response.status_code, status.HTTP_404_NOT_FOUND)

    def test_that_changing_field_names_works(self):
        payload = {
            "description":
            "RenameFieldName",
            "new_index":
            self.new_index_name,
            "fields": [TEST_FIELD],
            "field_type": [{
                "path": TEST_FIELD,
                "new_path_name": TEST_FIELD_RENAMED,
                "field_type": "text"
            }],
            "indices": [self.test_index_name],
            "add_facts_mapping":
            True
        }

        # Reindex the test index into a new one.
        url = reverse("v2:reindexer-list",
                      kwargs={"project_pk": self.project.pk})
        reindex_response = self.client.post(url, data=payload, format='json')
        print_output('test_that_changing_field_names_works:response.data',
                     reindex_response.data)

        # Check that the fields have been changed.
        es = ElasticSearcher(indices=[self.new_index_name])
        for document in es:
            self.assertTrue(TEST_FIELD not in document)
            self.assertTrue(TEST_FIELD_RENAMED in document)

        # Manual clean up.
        es.core.delete_index(self.new_index_name)
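
The check_new_index_validation calls above expect an "invalid_index_name" error for each REINDEXER_VALIDATION_TEST_INDEX_* constant. A hedged sketch of the kind of validation that error implies; the rules follow Elasticsearch's documented index-naming constraints, while the function itself is hypothetical:

import re

def is_valid_index_name(name: str) -> bool:
    if not name or name in (".", ".."):
        return False
    if name != name.lower():             # must be lowercase
        return False
    if name[0] in "-_+":                 # cannot start with '-', '_' or '+'
        return False
    if len(name.encode("utf-8")) > 255:  # at most 255 bytes
        return False
    # Characters Elasticsearch forbids inside index names.
    return re.search(r'[\\/*?"<>| ,#:]', name) is None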
Example #9
class IndexSplitterViewTests(APITransactionTestCase):
    def setUp(self):
        """ User needs to be admin, because of changed indices permissions. """
        self.test_index_name = reindex_test_dataset()
        self.default_password = '******'
        self.default_username = '******'
        self.user = create_test_user(self.default_username, '*****@*****.**',
                                     self.default_password)

        self.admin = create_test_user(name='admin', password='******')
        self.admin.is_superuser = True
        self.admin.save()
        self.project = project_creation("IndexSplittingTestProject",
                                        self.test_index_name, self.user)
        self.project.users.add(self.user)

        self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/index_splitter/'

        self.client.login(username=self.default_username,
                          password=self.default_password)
        self.ec = ElasticCore()
        self.FACT = "TEEMA"

    def test_create_splitter_object_and_task_signal(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }

        response = self.client.post(self.url,
                                    json.dumps(payload),
                                    content_type='application/json')

        print_output(
            'test_create_splitter_object_and_task_signal:response.data',
            response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])
        print_output("indices:", splitter_obj.get_indices())
        # Check if IndexSplitter object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Check if Task gets created
        self.assertTrue(splitter_obj.task is not None)
        print_output("status of IndexSplitter's Task object",
                     splitter_obj.task.status)
        # Check if Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)

        sleep(5)

        original_count = ElasticSearcher(indices=self.test_index_name).count()
        test_count = ElasticSearcher(
            indices=INDEX_SPLITTING_TEST_INDEX).count()
        train_count = ElasticSearcher(
            indices=INDEX_SPLITTING_TRAIN_INDEX).count()

        print_output('original_count, test_count, train_count',
                     [original_count, test_count, train_count])

    def test_create_random_split(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }

        response = self.client.post(self.url, data=payload)
        print_output('test_create_random_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        # Assert Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_count = ElasticSearcher(indices=self.test_index_name).count()
        test_count = ElasticSearcher(
            indices=INDEX_SPLITTING_TEST_INDEX).count()
        train_count = ElasticSearcher(
            indices=INDEX_SPLITTING_TRAIN_INDEX).count()

        print_output('original_count, test_count, train_count',
                     [original_count, test_count, train_count])
        # To avoid inconsistencies caused by rounding, assert that the sizes
        # stay within small limits of the expected 20/80 ratios.
        self.assertTrue(self.is_between_limits(test_count, original_count, 0.2))
        self.assertTrue(self.is_between_limits(train_count, original_count, 0.8))

    def test_create_original_split(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT
        }

        response = self.client.post(self.url, data=payload)
        print_output('test_create_original_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        # Assert Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(
            indices=self.test_index_name).get_fact_values_distribution(
                self.FACT)
        test_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
                self.FACT)
        train_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
                self.FACT)

        print_output(
            'original_dist, test_dist, train_dist',
            [original_distribution, test_distribution, train_distribution])

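        # Each label's test/train counts should roughly follow the requested 20/80 split.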
        for label, quant in original_distribution.items():
            self.assertTrue(
                self.is_between_limits(test_distribution[label], quant, 0.2))
            self.assertTrue(
                self.is_between_limits(train_distribution[label], quant, 0.8))

    def test_create_equal_split(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "equal",
            "test_size": 20,
            "fact": self.FACT
        }

        response = self.client.post(self.url, data=payload)
        print_output('test_create_equal_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        # Assert Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(
            indices=self.test_index_name).get_fact_values_distribution(
                self.FACT)
        test_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
                self.FACT)
        train_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
                self.FACT)

        print_output(
            'original_dist, test_dist, train_dist',
            [original_distribution, test_distribution, train_distribution])

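        # With the "equal" distribution, at most test_size documents per label
        # go to the test index; any remainder stays in the train index.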
        for label, quant in original_distribution.items():
            if quant > 20:
                self.assertEqual(test_distribution[label], 20)
                self.assertEqual(train_distribution[label], quant - 20)
            else:
                self.assertEqual(test_distribution[label], quant)
                self.assertTrue(label not in train_distribution)

    def test_create_custom_split(self):
        custom_distribution = {"FUBAR": 10, "bar": 15}
        payload = {
            "description": "Original index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "custom",
            "fact": self.FACT,
            "custom_distribution": json.dumps(custom_distribution)
        }

        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_create_custom_split:response.data', response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        # Assert Task gets completed
        self.assertEqual(splitter_obj.task.status, Task.STATUS_COMPLETED)
        print_output("Task status", splitter_obj.task.status)

        sleep(5)

        original_distribution = ElasticAggregator(
            indices=self.test_index_name).get_fact_values_distribution(
                self.FACT)
        test_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
                self.FACT)
        train_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
                self.FACT)

        print_output(
            'original_dist, test_dist, train_dist',
            [original_distribution, test_distribution, train_distribution])

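        # With the "custom" distribution, each listed label receives at most the
        # requested quantity in the test index; unlisted labels should remain
        # only in the train index.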
        for label, quant in custom_distribution.items():
            self.assertEqual(test_distribution[label],
                             min(quant, original_distribution[label]))

        for label in original_distribution.keys():
            if label not in custom_distribution:
                self.assertTrue(label not in test_distribution)
                self.assertEqual(original_distribution[label],
                                 train_distribution[label])

    def test_create_original_split_fact_value_given(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT,
            "str_val": "FUBAR"
        }

        response = self.client.post(self.url, data=payload, format="json")
        print_output(
            'test_create_original_split_fact_value_given:response.data',
            response.data)

        splitter_obj = IndexSplitter.objects.get(id=response.data['id'])

        sleep(5)

        original_distribution = ElasticAggregator(
            indices=self.test_index_name).get_fact_values_distribution(
                self.FACT)
        test_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
                self.FACT)
        train_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
                self.FACT)

        print_output(
            'original_dist, test_dist, train_dist',
            [original_distribution, test_distribution, train_distribution])

        for label, quant in original_distribution.items():
            if label == "FUBAR":
                self.assertTrue(
                    self.is_between_limits(test_distribution[label], quant,
                                           0.2))
                self.assertTrue(
                    self.is_between_limits(train_distribution[label], quant,
                                           0.8))

    def test_query_given(self):
        payload = {
            "description": "Original index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "original",
            "test_size": 20,
            "fact": self.FACT,
            "str_val": "bar",
            "query": json.dumps(TEST_QUERY)
        }

        response = self.client.post(self.url, data=payload, format="json")
        print_output('test_query_given:response.data', response.data)

        original_distribution = ElasticAggregator(
            indices=self.test_index_name).get_fact_values_distribution(
                self.FACT)
        test_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TEST_INDEX).get_fact_values_distribution(
                self.FACT)
        train_distribution = ElasticAggregator(
            indices=INDEX_SPLITTING_TRAIN_INDEX).get_fact_values_distribution(
                self.FACT)

        print_output(
            'original_dist, test_dist, train_dist',
            [original_distribution, test_distribution, train_distribution])

        self.assertTrue("bar" in test_distribution)
        self.assertTrue("bar" in train_distribution)
        self.assertTrue("foo" not in train_distribution
                        and "foo" not in test_distribution)
        self.assertTrue("FUBAR" not in train_distribution
                        and "FUBAR" not in test_distribution)

    def tearDown(self):
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
        res = self.ec.delete_index(INDEX_SPLITTING_TEST_INDEX)
        print_output('attempt to delete test index:', res)
        res = self.ec.delete_index(INDEX_SPLITTING_TRAIN_INDEX)
        print_output('attempt to delete train index:', res)

    def is_between_limits(self, value, base, ratio):
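        # Allow one document of slack around the expected share (base * ratio)
        # to absorb rounding during the split.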
        return base * ratio - 1 <= value <= base * ratio + 1

    # There used to be a bug in which objects were flattened in the split index unintentionally.
    def test_that_split_index_with_nested_field_still_has_nested_field(self):
        payload = {
            "description": "Random index splitting",
            "indices": [{
                "name": self.test_index_name
            }],
            "train_index": INDEX_SPLITTING_TRAIN_INDEX,
            "test_index": INDEX_SPLITTING_TEST_INDEX,
            "distribution": "random",
            "test_size": 20
        }

        response = self.client.post(self.url, data=payload, format="json")
        print_output(
            'test_that_split_index_with_nested_field_still_has_nested_field:response.data',
            response.data)
        at_least_once = False
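        # flatten=False keeps nested objects as dictionaries instead of
        # flattening them into dotted keys, so the mapping can be verified.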
        es = ElasticSearcher(
            indices=[INDEX_SPLITTING_TRAIN_INDEX, INDEX_SPLITTING_TEST_INDEX],
            field_data=[TEST_INDEX_OBJECT_FIELD],
            flatten=False)
        for item in es:
            data = item.get(TEST_INDEX_OBJECT_FIELD, None)
            if data:
                self.assertTrue(isinstance(data, dict))
                at_least_once = True
        self.assertTrue(at_least_once)
Beispiel #10
0
 def tearDown(self) -> None:
     ec = ElasticCore()
     ec.delete_index(self.test_index_name)
Beispiel #11
0
class MLPIndexProcessing(APITransactionTestCase):
    def setUp(self):
        self.test_index_name = reindex_test_dataset()
        self.ec = ElasticCore()
        self.user = create_test_user('mlpUser', '*****@*****.**', 'pw')
        self.project = project_creation("mlpTestProject", self.test_index_name,
                                        self.user)
        self.project.users.add(self.user)
        self.client.login(username='******', password='******')
        self.url = reverse(f"{VERSION_NAMESPACE}:mlp_index-list",
                           kwargs={"project_pk": self.project.pk})

    def tearDown(self) -> None:
        self.ec.delete_index(self.test_index_name, ignore=[400, 404])

    def _assert_mlp_contents(self, hit: dict, test_field: str):
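        # The searcher output is flattened, so the nested MLP results show up
        # under dotted keys such as "<field>_mlp.lemmas".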
        self.assertTrue(f"{test_field}_mlp.lemmas" in hit)
        self.assertTrue(f"{test_field}_mlp.pos_tags" in hit)
        self.assertTrue(f"{test_field}_mlp.text" in hit)
        self.assertTrue(f"{test_field}_mlp.language.analysis" in hit)
        self.assertTrue(f"{test_field}_mlp.language.detected" in hit)

    def test_index_processing(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "fields": [TEST_FIELD],
            "query": json.dumps(
                {'query': {'match': {'comment_content_lemmas': query_string}}},
                ensure_ascii=False)
        }

        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_index_processing:response.data", response.data)

        # Check if MLP was applied to the documents properly.
        s = ElasticSearcher(indices=[self.test_index_name],
                            output=ElasticSearcher.OUT_DOC,
                            query=payload["query"])
        for hit in s:
            self._assert_mlp_contents(hit, TEST_FIELD)

    def _check_for_if_query_correct(self, hit: dict, field_name: str,
                                    query_string: str):
        text = hit[field_name]
        self.assertTrue(query_string in text)

    def test_payload_without_fields_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "query": json.dumps(
                {'query': {'match': {'comment_content_lemmas': query_string}}},
                ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_without_fields_value:response.data",
                     response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)
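        # DRF's default error message for a missing required field.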
        self.assertTrue(
            response.data["fields"][0] == "This field is required.")

    def test_payload_with_empty_fields_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "query": json.dumps(
                {'query': {'match': {'comment_content_lemmas': query_string}}},
                ensure_ascii=False),
            "fields": []
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_without_fields_value:response.data",
                     response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)
        self.assertTrue(
            response.data["fields"][0] == "This list may not be empty.")

    def test_payload_with_invalid_field_value(self):
        query_string = "inimene"
        payload = {
            "description": "TestingIndexProcessing",
            "fields": ["this_field_does_not_exist"],
            "query": json.dumps(
                {'query': {'match': {'comment_content_lemmas': query_string}}},
                ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_payload_with_invalid_field_value:response.data",
                     response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)

    def test_applying_mlp_on_two_indices(self):
        query_string = "inimene"
        indices = [f"texta_test_{uuid.uuid1()}", f"texta_test_{uuid.uuid1()}"]
        for index in indices:
            self.ec.es.indices.create(index=index, ignore=[400, 404])
            self.ec.es.index(index=index,
                             body={"text": "obscure content to parse!"})
            index_obj, is_created = Index.objects.get_or_create(name=index)
            self.project.indices.add(index_obj)

        payload = {
            "description": "TestingIndexProcessing",
            "fields": ["text"],
            "indices": [{"name": index} for index in indices],
            "query": json.dumps(
                {'query': {'match': {'comment_content_lemmas': query_string}}},
                ensure_ascii=False)
        }
        response = self.client.post(self.url, data=payload, format="json")
        print_output("test_applying_mlp_on_two_indices:response.data",
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        for index in indices:
            self.ec.es.indices.delete(index=index, ignore=[400, 404])
Beispiel #12
0
 def tearDown(self):
     # delete created indices
     ec = ElasticCore()
     for index in self.created_indices:
         delete_response = ec.delete_index(index)
         print_output("Remove index:response", delete_response)
Beispiel #13
0
 def __remove_reindexed_test_index(self):
     ec = ElasticCore()
     result = ec.delete_index(index=self.test_index_name, ignore=[400, 404])
     print_output(
         f"Deleting ProjectViewTests test index {self.test_index_name}",
         result)
Beispiel #14
0
class TestDocparserAPIView(APITestCase):

    def setUp(self) -> None:
        self.test_index_name = reindex_test_dataset()
        self.user = create_test_user('Owner', '*****@*****.**', 'pw')
        self.unauthorized_user = create_test_user('unauthorized', '*****@*****.**', 'pw')
        self.file_name = "d41d8cd98f00b204e9800998ecf8427e.txt"

        self.project = project_creation("test_doc_parser", index_title=None, author=self.user)
        self.project.users.add(self.user)
        self.unauth_project = project_creation("unauth_project", index_title=None, author=self.user)

        self.file = SimpleUploadedFile("text.txt", b"file_content", content_type="text/html")
        self.client.login(username='******', password='******')
        self._basic_pipeline_functionality()
        self.file_path = self._get_file_path()
        self.ec = ElasticCore()


    def tearDown(self) -> None:
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])


    def _get_file_path(self):
        path = pathlib.Path(settings.RELATIVE_PROJECT_DATA_PATH) / str(self.project.pk) / "docparser" / self.file_name
        return path


    def _basic_pipeline_functionality(self):
        url = reverse(f"{VERSION_NAMESPACE}:docparser")
        payload = {
            "file": self.file,
            "project_id": self.project.pk,
            "indices": [self.test_index_name],
            "file_name": self.file_name
        }
        response = self.client.post(url, data=payload)
        print_output("_basic_pipeline_functionality:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)


    def test_file_appearing_in_proper_structure(self):
        print_output("test_file_appearing_in_proper_structure", self.file_path.exists())
        self.assertTrue(self.file_path.exists())


    def test_being_rejected_without_login(self):
        url = reverse(f"{VERSION_NAMESPACE}:docparser")
        self.client.logout()
        payload = {
            "file": self.file,
            "project_id": self.project.pk,
            "indices": [self.test_index_name]
        }
        response = self.client.post(url, data=payload)
        print_output("test_being_rejected_without_login:response.data", response.data)
        response_code = response.status_code
        print_output("test_being_rejected_without_login:response.status_code", response_code)

        self.assertTrue((response_code == status.HTTP_403_FORBIDDEN) or (response_code == status.HTTP_401_UNAUTHORIZED))


    def test_being_rejected_with_wrong_project_id(self):
        url = reverse(f"{VERSION_NAMESPACE}:docparser")
        payload = {
            "file": self.file,
            "project_id": self.unauth_project.pk,
            "indices": [self.test_index_name]
        }
        self.unauth_project.users.remove(self.user)
        response = self.client.post(url, data=payload)
        print_output("test_being_rejected_with_wrong_project_id:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN)


    def test_indices_being_added_into_the_project(self):
        project = Project.objects.get(pk=self.project.pk)
        indices = project.indices.all()
        added_index = indices.filter(name=self.test_index_name)
        self.assertTrue(added_index.count() == 1)
        print_output("test_indices_being_added_into_the_project", True)


    def test_that_serving_media_works_for_authenticated_users(self):
        file_name = self.file_path.name
        url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name})
        response = self.client.get(url)
        print_output("test_that_serving_media_works_for_authenticated_users", True)
        self.assertTrue(response.status_code == status.HTTP_200_OK)


    def test_that_serving_media_doesnt_work_for_unauthenticated_users(self):
        self.client.logout()
        file_name = self.file_path.name
        url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name})
        response = self.client.get(url)
        print_output("test_that_serving_media_doesnt_work_for_unauthenticated_users", True)
        self.assertTrue(response.status_code == status.HTTP_302_FOUND)


    def test_media_access_for_unauthorized_projects(self):
        self.client.login(username="******", password="******")
        file_name = self.file_path.name
        url = reverse("protected_serve", kwargs={"project_id": self.project.pk, "application": "docparser", "file_name": file_name})
        response = self.client.get(url)
        print_output("test_media_access_for_unauthorized_projects", True)
        self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN)


    def test_that_saved_file_size_isnt_zero(self):
        """
        Necessary because of a prior bug where the wrapper would save a file
        with the right name but not its contents.
        """
        import time
        time.sleep(10)
        file_size = os.path.getsize(self.file_path)
        self.assertTrue(file_size > 1)
        print_output("test_that_saved_file_size_isnt_zero::file_size:int", file_size)


    def test_payload_with_empty_indices(self):
        url = reverse(f"{VERSION_NAMESPACE}:docparser")
        payload = {
            "file": SimpleUploadedFile("text.txt", b"file_content", content_type="text/html"),
            "project_id": self.project.pk,
            "file_name": self.file_name
        }
        response = self.client.post(url, data=payload)
        print_output("_basic_pipeline_functionality:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)
Beispiel #15
0
class DocumentImporterAPITestCase(APITestCase):

    def setUp(self):
        self.test_index_name = reindex_test_dataset()
        self.user = create_test_user('first_user', '*****@*****.**', 'pw')
        self.project = project_creation("DocumentImporterAPI", self.test_index_name, self.user)

        self.validation_project = project_creation("validation_project", "random_index_name", self.user)

        self.document_id = random.randint(10000000, 90000000)
        self.uuid = uuid.uuid1()
        self.source = {"hello": "world", "uuid": self.uuid}
        self.document = {"_index": self.test_index_name, "_id": self.document_id, "_source": self.source}

        self.target_field_random_key = uuid.uuid1()
        self.target_field = f"{self.target_field_random_key}_court_case"
        self.ec = ElasticCore()

        self.client.login(username='******', password='******')
        self._check_inserting_documents()


    def _check_inserting_documents(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        response = self.client.post(url, data={"documents": [self.document], "split_text_in_fields": []}, format="json")
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        document = self.ec.es.get(id=self.document_id, index=self.test_index_name)
        print_output("_check_inserting_documents:response.data", response.data)
        self.assertTrue(document["_source"])


    def tearDown(self) -> None:
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
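        # Remove any imported documents carrying the randomized target field,
        # regardless of which index they ended up in.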
        query = Search().query(Q("exists", field=self.target_field)).to_dict()
        self.ec.es.delete_by_query(index="*", body=query, wait_for_completion=True)


    def test_adding_documents_to_false_project(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.validation_project.pk})
        self.validation_project.users.remove(self.user)
        response = self.client.post(url, data={"documents": [self.document]}, format="json")
        self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN)
        print_output("test_adding_documents_to_false_project:response.data", response.data)


    def test_adding_documents_to_false_index(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        index_name = "wrong_index"
        response = self.client.post(url, data={"documents": [{"_index": index_name, "_source": self.document}]}, format="json")
        try:
            self.ec.es.get(id=self.document_id, index=index_name)
        except NotFoundError:
            print_output("test_adding_documents_to_false_index:response.data", response.data)
        else:
            raise Exception("Elasticsearch indexed a document it shouldn't have!")


    def test_updating_document(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": self.document_id})
        response = self.client.patch(url, data={"hello": "night", "goodbye": "world"})
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        document = self.ec.es.get(index=self.test_index_name, id=self.document_id)["_source"]
        self.assertTrue(document["hello"] == "night" and document["goodbye"] == "world")
        print_output("test_updating_document:response.data", response.data)


    def test_deleting_document(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": self.document_id})
        response = self.client.delete(url)
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        try:
            self.ec.es.get(id=self.document_id, index=self.test_index_name)
        except NotFoundError:
            print_output("test_deleting_document:response.data", response.data)
        else:
            raise Exception("Elasticsearch didn't delete a document it should have!")


    def test_unauthenticated_access(self):
        self.client.logout()
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        response = self.client.post(url, data={"documents": [self.document]}, format="json")
        print_output("test_unauthenticated_access:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_403_FORBIDDEN or response.status_code == status.HTTP_401_UNAUTHORIZED)


    def test_adding_document_without_specified_index_and_that_index_is_added_into_project(self):
        from toolkit.elastic.document_importer.views import DocumentImportView
        sample_id = 65959645
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        response = self.client.post(url, data={"documents": [{"_source": self.source, "_id": sample_id}]}, format="json")
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        normalized_project_title = DocumentImportView.get_new_index_name(self.project.pk)
        document = self.ec.es.get(id=sample_id, index=normalized_project_title)

        self.ec.delete_index(normalized_project_title)  # Cleanup

        self.assertTrue(document["_source"])
        self.assertTrue(self.project.indices.filter(name=normalized_project_title).exists())
        print_output("test_adding_document_without_specified_index_and_that_index_is_added_into_project:response.data", response.data)


    def test_updating_non_existing_document(self):
        sample_id = "random_id"
        url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": sample_id})
        response = self.client.patch(url, data={"hello": "world"}, format="json")
        print_output("test_updating_non_existing_document:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_404_NOT_FOUND)


    def test_deleting_non_existing_document(self):
        sample_id = "random_id"
        url = reverse(f"{VERSION_NAMESPACE}:document_instance", kwargs={"pk": self.project.pk, "index": self.test_index_name, "document_id": sample_id})
        response = self.client.delete(url, data={"hello": "world"}, format="json")
        print_output("test_deleting_non_existing_document:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_404_NOT_FOUND)


    def test_that_specified_field_is_being_split(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        uuid = "456694-asdasdad4-54646ad-asd4a5d"
        response = self.client.post(
            url,
            format="json",
            data={
                "split_text_in_fields": [self.target_field],
                "documents": [{
                    "_index": self.test_index_name,
                    "_source": {
                        self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...",
                        "uuid": uuid
                    }
                }]},
        )
        print_output("test_that_specified_field_is_being_split:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}})
        document = documents["hits"]["hits"][0]
        self.assertTrue(document)
        self.assertTrue(document["_source"])
        self.assertTrue("page" in document["_source"])


    def test_that_wrong_field_value_will_skip_splitting(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        uuid = "Adios"
        response = self.client.post(
            url,
            format="json",
            data={
                "documents": [{
                    "_index": self.test_index_name,
                    "_source": {
                        self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...",
                        "uuid": uuid
                    }
                }]},
        )
        print_output("test_that_empty_field_value_will_skip_splitting:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}})
        document = documents["hits"]["hits"][0]
        self.assertTrue(document)
        self.assertTrue(document["_source"])
        self.assertTrue("page" not in document["_source"])


    def test_splitting_behaviour_with_empty_list_as_input(self):
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        uuid = "adasdasd-5g465s-fa4s69f4a8s97-a4das9f4"
        response = self.client.post(
            url,
            format="json",
            data={
                "split_text_in_fields": [],
                "documents": [{
                    "_index": self.test_index_name,
                    "_source": {
                        self.target_field: "Paradna on kohtu alla antud kokkuleppe alusel selles, et tema, 25.10.2003 kell 00.30 koos,...",
                        "uuid": uuid
                    }
                }]},
        )
        print_output("test_splitting_behaviour_with_empty_list_as_input:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)
        documents = self.ec.es.search(index=self.test_index_name, body={"query": {"term": {"uuid.keyword": uuid}}})
        document = documents["hits"]["hits"][0]
        self.assertTrue(document)
        self.assertTrue(document["_source"])
        self.assertTrue("page" not in document["_source"])


    def test_that_indexes_do_rollover_after_a_certain_count(self):
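        # Lower the per-index document limit so that the rollover to a new
        # index triggers quickly during the test.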
        set_core_setting("TEXTA_ES_MAX_DOCS_PER_INDEX", "50")
        url = reverse(f"{VERSION_NAMESPACE}:document_import", kwargs={"pk": self.project.pk})
        response = self.client.post(url, data={"documents": [{"_source": {"value": i}} for i in range(51)], "split_text_in_fields": []}, format="json")
        print_output("test_that_indexes_do_rollover_after_a_certain_count:response.data", response.data)
        self.assertTrue(response.status_code == status.HTTP_200_OK)

        # Send a batch for the second time to see if the documents rolled over properly.
        clean_up_container = ["texta-1-import-project-1"]
        for i in range(1, 12):
            response = self.client.post(url, data={"documents": [{"_source": {"value": i}} for i in range(51)], "split_text_in_fields": []}, format="json")

            self.assertTrue(response.status_code == status.HTTP_200_OK)
            if i == 1:
                self.assertTrue(self.project.indices.filter(name="texta-1-import-project-1").exists())
            else:
                # i-1 because the first request doesn't add the rollover index counter.
                rollover_index_name = f"texta-1-import-project-1-{i - 1}"
                self.assertTrue(self.project.indices.filter(name=rollover_index_name).exists())
                clean_up_container.append(rollover_index_name)

        # Cleanup
        self.ec.delete_index(index=",".join(clean_up_container))
Beispiel #16
0
class TorchTaggerViewTests(APITransactionTestCase):
    def setUp(self):
        # Owner of the project
        self.test_index_name = reindex_test_dataset()
        self.user = create_test_user('torchTaggerOwner', '*****@*****.**', 'pw')
        self.project = project_creation('torchTaggerTestProject',
                                        self.test_index_name, self.user)
        self.project.users.add(self.user)
        self.url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/torchtaggers/'
        self.project_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}'
        self.test_embedding_id = None
        self.torch_models = list(TORCH_MODELS.keys())
        self.test_tagger_id = None
        self.test_multiclass_tagger_id = None

        self.client.login(username='******', password='******')

        # new fact name and value used when applying tagger to index
        self.new_fact_name = "TEST_TORCH_TAGGER_NAME"
        self.new_multiclass_fact_name = "TEST_TORCH_TAGGER_NAME_MC"
        self.new_fact_value = "TEST_TORCH_TAGGER_VALUE"

        # Create copy of test index
        self.reindex_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/elastic/reindexer/'
        # Generate name for new index containing random id to make sure it doesn't already exist
        self.test_index_copy = f"test_apply_torch_tagger_{uuid.uuid4().hex}"

        self.reindex_payload = {
            "description": "test index for applying taggers",
            "indices": [self.test_index_name],
            "query": json.dumps(TEST_QUERY),
            "new_index": self.test_index_copy,
            "fields": TEST_FIELD_CHOICE
        }
        resp = self.client.post(self.reindex_url,
                                self.reindex_payload,
                                format='json')
        print_output(
            "reindex test index for applying torch tagger:response.data:",
            resp.json())
        self.reindexer_object = Reindexer.objects.get(pk=resp.json()["id"])
        self.ec = ElasticCore()

    def import_test_model(self, file_path: str):
        """Import models for testing."""
        print_output("Importing model from file:", file_path)
        import_url = f'{self.url}import_model/'
        # Open the file once and let the context manager close the handle.
        with open(file_path, "rb") as file:
            resp = self.client.post(import_url, data={'file': file}).json()
        print_output("Importing test model:", resp)
        return resp["id"]

    def tearDown(self) -> None:
        res = self.ec.delete_index(self.test_index_copy)
        self.ec.delete_index(index=self.test_index_name, ignore=[400, 404])
        print_output(
            f"Delete apply_torch_taggers test index {self.test_index_copy}",
            res)

    def test(self):
        pass
        # self.run_train_embedding()
        # self.run_train_tagger_using_query()
        # self.run_train_torchtagger_without_embedding()
        # self.run_train_multiclass_tagger_using_fact_name()
        # self.run_train_balanced_multiclass_tagger_using_fact_name()
        # self.run_train_binary_multiclass_tagger_using_fact_name()
        # self.run_train_binary_multiclass_tagger_using_fact_name_invalid_payload()
        # self.run_tag_text()
        # self.run_model_export_import()
        # self.run_tag_with_imported_gpu_model() # were already commented out
        # self.run_tag_with_imported_cpu_model() # were already commented out
        # self.run_tag_random_doc()
        # self.run_epoch_reports_get()
        # self.run_epoch_reports_post()
        # self.run_tag_and_feedback_and_retrain()
        # self.run_apply_binary_tagger_to_index()
        # self.run_apply_tagger_to_index_invalid_input()

    def add_cleanup_files(self, tagger_id):
        tagger_object = TorchTagger.objects.get(pk=tagger_id)
        self.addCleanup(remove_file, tagger_object.model.path)
        if not TEST_KEEP_PLOT_FILES:
            self.addCleanup(remove_file, tagger_object.plot.path)
        self.addCleanup(remove_file,
                        tagger_object.embedding.embedding_model.path)

    def run_train_embedding(self):
        # payload for training embedding
        payload = {
            "description": "TestEmbedding",
            "fields": TEST_FIELD_CHOICE,
            "max_vocab": 10000,
            "min_freq": 5,
            "num_dimensions": 300
        }
        # post
        embeddings_url = f'{TEST_VERSION_PREFIX}/projects/{self.project.id}/embeddings/'
        response = self.client.post(embeddings_url, payload, format='json')
        self.test_embedding_id = response.data["id"]
        print_output("run_train_embedding", 201)

    def run_train_torchtagger_without_embedding(self):
        payload = {
            "description": "TestTorchTaggerTraining",
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3
        }

        response = self.client.post(self.url, payload, format='json')
        print_output(f"run_train_torchtagger_without_embedding", response.data)
        self.assertTrue(response.status_code == status.HTTP_400_BAD_REQUEST)

    def run_train_tagger_using_query(self):
        """Tests TorchTagger training, and if a new Task gets created via the signal"""
        payload = {
            "description": "TestTorchTaggerTraining",
            "fields": TEST_FIELD_CHOICE,
            "query": json.dumps(TEST_QUERY),
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3,
            "embedding": self.test_embedding_id,
        }

        print_output(f"training tagger with payload: {payload}", 200)
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_create_binary_torchtagger_training_and_task_signal:response.data',
            response.data)

        # Check if the TorchTagger object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        # Check if f1 not NULL (train and validation success)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_torchtagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output('test_torchtagger_has_classes:response.data.classes',
                     response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)

        self.test_tagger_id = tagger_id
        # add cleanup
        self.add_cleanup_files(tagger_id)

    def run_train_multiclass_tagger_using_fact_name(self):
        """Tests TorchTagger training with multiple classes and if a new Task gets created via the signal"""
        payload = {
            "description": "TestTorchTaggerTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3,
            "embedding": self.test_embedding_id,
        }
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_create_multiclass_torchtagger_training_and_task_signal:response.data',
            response.data)
        # Check if the TorchTagger object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Check if f1 not NULL (train and validation success)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_torchtagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output('test_torchtagger_has_classes:response.data.classes',
                     response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) > 2)

        self.test_multiclass_tagger_id = tagger_id
        # add cleanup
        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_tagger_using_fact_name(self):
        """Tests TorchTagger training with binary facts."""
        payload = {
            "description": "TestBinaryMulticlassTorchTaggerTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3,
            "embedding": self.test_embedding_id,
            "pos_label": TEST_POS_LABEL,
            "query": json.dumps(TEST_BIN_FACT_QUERY)
        }
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_create_binary_multiclass_torchtagger_training_and_task_signal:response.data',
            response.data)
        # Check if the TorchTagger object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Check if f1 not NULL (train and validation success)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_torchtagger_has_stats:response.data', response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        print_output('test_torchtagger_has_classes:response.data.classes',
                     response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) == 2)
        # add cleanup
        self.add_cleanup_files(tagger_id)

    def run_train_binary_multiclass_tagger_using_fact_name_invalid_payload(
            self):
        """Tests TorchTagger training with binary facts and invalid payload."""

        # Pos label is undefined by the user
        invalid_payload_1 = {
            "description": "TestBinaryMulticlassTorchTaggerTrainingMissingPosLabel",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3,
            "embedding": self.test_embedding_id,
            "query": json.dumps(TEST_BIN_FACT_QUERY)
        }
        response = self.client.post(self.url, invalid_payload_1, format='json')
        print_output(
            'test_create_binary_multiclass_torchtagger_using_fact_name_missing_pos_label:response.data',
            response.data)
        # Check if creating the Tagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        # The inserted pos label is not present in the data
        invalid_payload_2 = {
            "description": "TestBinaryMulticlassTorchTaggerTrainingMissingPosLabel",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 500,
            "model_architecture": self.torch_models[0],
            "num_epochs": 3,
            "embedding": self.test_embedding_id,
            "query": json.dumps(TEST_BIN_FACT_QUERY),
            "pos_label": "invalid_fact_val"
        }
        response = self.client.post(self.url, invalid_payload_2, format='json')
        print_output(
            'test_create_binary_multiclass_torchtagger_using_fact_name_invalid_pos_label:response.data',
            response.data)
        # Check if creating the Tagger fails with status code 400
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_train_balanced_multiclass_tagger_using_fact_name(self):
        """Tests TorchTagger training with multiple balanced classes and if a new Task gets created via the signal"""
        payload = {
            "description": "TestBalancedTorchTaggerTraining",
            "fact_name": TEST_FACT_NAME,
            "fields": TEST_FIELD_CHOICE,
            "maximum_sample_size": 150,
            "model_architecture": self.torch_models[0],
            "num_epochs": 2,
            "embedding": self.test_embedding_id,
            "balance": True,
            "use_sentence_shuffle": True,
            "balance_to_max_limit": True
        }
        response = self.client.post(self.url, payload, format='json')
        print_output(
            'test_create_balanced_multiclass_torchtagger_training_and_task_signal:response.data',
            response.data)
        # Check if the TorchTagger object gets created
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        # Check if f1 not NULL (train and validation success)
        tagger_id = response.data['id']
        response = self.client.get(f'{self.url}{tagger_id}/')
        print_output('test_balanced_torchtagger_has_stats:response.data',
                     response.data)
        for score in ['f1_score', 'precision', 'recall', 'accuracy']:
            self.assertTrue(isinstance(response.data[score], float))

        num_examples = json.loads(response.data["num_examples"])
        print_output(
            'test_balanced_torchtagger_num_examples_correct:num_examples',
            num_examples)
        for class_size in num_examples.values():
            self.assertEqual(class_size, payload["maximum_sample_size"])

        print_output('test_balanced_torchtagger_has_classes:classes',
                     response.data["classes"])
        self.assertTrue(isinstance(response.data["classes"], list))
        self.assertTrue(len(response.data["classes"]) >= 2)
        # add cleanup
        self.add_cleanup_files(tagger_id)

    def run_tag_text(self):
        """Tests tag prediction for texts."""
        payload = {"text": "mine kukele, kala"}
        response = self.client.post(
            f'{self.url}{self.test_tagger_id}/tag_text/', payload)
        print_output('test_torchtagger_tag_text:response.data', response.data)
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue('result' in response.data)
        self.assertTrue('probability' in response.data)
        self.assertTrue('tagger_id' in response.data)

    def run_tag_random_doc(self):
        """Tests the endpoint for the tag_random_doc action"""
        payload = {"indices": [{"name": self.test_index_name}]}
        url = f'{self.url}{self.test_tagger_id}/tag_random_doc/'
        response = self.client.post(url, format="json", data=payload)
        print_output('test_tag_random_doc:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a dict
        self.assertTrue(isinstance(response.data, dict))
        self.assertTrue('prediction' in response.data)
        self.assertTrue('result' in response.data['prediction'])
        self.assertTrue('probability' in response.data['prediction'])
        self.assertTrue('tagger_id' in response.data['prediction'])

    def run_epoch_reports_get(self):
        """Tests endpoint for retrieving epoch reports via GET"""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        response = self.client.get(url, format="json")
        print_output('test_torchagger_epoch_reports_get:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report is not empty
        self.assertTrue(len(response.data[0]) > 0)

    def run_epoch_reports_post(self):
        """Tests endpoint for retrieving epoch reports via GET"""
        url = f'{self.url}{self.test_tagger_id}/epoch_reports/'
        payload_1 = {}
        payload_2 = {
            "ignore_fields": ["true_positive_rate", "false_positive_rate", "recall"]
        }
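        # payload_1 relies on the endpoint's default ignore list (which keeps
        # "recall" in the report), while payload_2 ignores it explicitly.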

        response = self.client.post(url, format="json", data=payload_1)
        print_output(
            'test_torchagger_epoch_reports_post_ignore_default:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report contains recall
        self.assertTrue("recall" in response.data[0])

        response = self.client.post(url, format="json", data=payload_2)
        print_output(
            'test_torchtagger_epoch_reports_post_ignore_custom:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Check if response is a list
        self.assertTrue(isinstance(response.data, list))
        # Check if first report does NOT contain recall
        self.assertTrue("recall" not in response.data[0])

    def run_model_export_import(self):
        """Tests endpoint for model export and import"""
        test_tagger_group_id = self.test_tagger_id

        # retrieve model zip
        url = f'{self.url}{test_tagger_group_id}/export_model/'
        response = self.client.get(url)

        # Post model zip
        import_url = f'{self.url}import_model/'
        response = self.client.post(import_url,
                                    data={'file': BytesIO(response.content)})
        tagger_id = response.data['id']
        print_output('test_import_model:response.data', response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        torchtagger = TorchTagger.objects.get(pk=response.data["id"])

        # Check if the models and plot files exist.
        resources = torchtagger.get_resource_paths()
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertTrue(file.exists())

        # Tests the endpoint for the tag_random_doc action
        url = f'{self.url}{torchtagger.pk}/tag_random_doc/'
        payload = {"indices": [{"name": self.test_index_name}]}
        response = self.client.post(url, format='json', data=payload)
        print_output(
            'test_torchtagger_tag_random_doc_after_import:response.data',
            response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        self.assertTrue('prediction' in response.data)
        self.assertTrue('result' in response.data['prediction'])
        self.assertTrue('probability' in response.data['prediction'])
        self.assertTrue('tagger_id' in response.data['prediction'])
        self.add_cleanup_files(tagger_id)

    def run_apply_binary_tagger_to_index(self):
        """Tests applying binary torch tagger to index using apply_to_index endpoint."""
        # Make sure reindexer task has finished
        while self.reindexer_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_binary_torch_tagger_to_index: waiting for reindexer task to finish, current status:',
                self.reindexer_object.task.status)
            sleep(2)

        url = f'{self.url}{self.test_tagger_id}/apply_to_index/'

        payload = {
            "description": "apply torch tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "indices": [{
                "name": self.test_index_copy
            }],
            "fields": TEST_FIELD_CHOICE
        }
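        # apply_to_index is processed as a background task, hence the status polling below.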
        response = self.client.post(url, payload, format='json')
        print_output('test_apply_binary_torch_tagger_to_index:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        tagger_object = TorchTagger.objects.get(pk=self.test_tagger_id)

        # Wait until the task has finished
        while tagger_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                'test_apply_binary_torch_tagger_to_index: waiting for applying tagger task to finish, current status:',
                tagger_object.task.status)
            sleep(2)

        results = ElasticAggregator(
            indices=[self.test_index_copy]).get_fact_values_distribution(
                self.new_fact_name)
        print_output(
            "test_apply_binary_torch_tagger_to_index:elastic aggerator results:",
            results)

        # Check if expected number of facts is added
        self.assertTrue(results[self.new_fact_value] > 10)

    def run_apply_tagger_to_index_invalid_input(self):
        """Tests applying multiclass torch tagger to index using apply_to_index endpoint."""

        url = f'{self.url}{self.test_tagger_id}/apply_to_index/'

        payload = {
            "description": "apply torch tagger to index test task",
            "new_fact_name": self.new_fact_name,
            "new_fact_value": self.new_fact_value,
            "fields": "invalid_field_format"
        }
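        # "fields" is expected to be a list of field names; a bare string should fail validation.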
        response = self.client.post(url, payload, format='json')
        print_output('test_invalid_apply_torch_tagger_to_index:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

        self.add_cleanup_files(self.test_tagger_id)

    def run_tag_and_feedback_and_retrain(self):
        """Tests feeback extra action."""
        tagger_id = self.test_tagger_id

        tagger_orm: TorchTagger = TorchTagger.objects.get(
            pk=self.test_tagger_id)
        model_path = pathlib.Path(tagger_orm.model.path)
        print_output(
            'run_tag_and_feedback_and_retrain:assert that the initial model exists',
            data=model_path.exists())
        self.assertTrue(model_path.exists())

        payload = {
            "text": "This is some test text for the Tagger Test",
            "feedback_enabled": True
        }
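        # With feedback_enabled, the response should include a feedback entry whose id is used below.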
        tag_text_url = f'{self.url}{tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output('test_tag_text_with_feedback:response.data',
                     response.data)
        self.assertTrue('feedback' in response.data)

        # generate feedback
        fb_id = response.data['feedback']['id']
        feedback_url = f'{self.url}{tagger_id}/feedback/'
        payload = {"feedback_id": fb_id, "correct_result": "FUBAR"}
        response = self.client.post(feedback_url, payload, format='json')
        print_output('test_tag_text_with_feedback:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue('success' in response.data)
        # sleep for a sec to allow elastic to finish its business
        sleep(1)
        # list feedback
        feedback_list_url = f'{self.url}{tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output('test_tag_text_list_feedback:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data)
        self.assertTrue(len(response.data) > 0)

        # register current model files for cleanup before retraining
        self.add_cleanup_files(tagger_id)

        # retrain model
        url = f'{self.url}{tagger_id}/retrain_tagger/'
        response = self.client.post(url)
        print_output('test_feedback:retrain', response.data)
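        # Retraining writes a new model file; the old one should be removed from disk (verified below).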
        # test tagging again for this model
        payload = {"text": "This is some test text for the Tagger Test"}
        tag_text_url = f'{self.url}{tagger_id}/tag_text/'
        response = self.client.post(tag_text_url, payload)
        print_output('test_feedback_retrained_tag_doc:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('result' in response.data)
        self.assertTrue('probability' in response.data)

        # Ensure that the previous model file was deleted properly.
        print_output(
            'test_model_retrain:assert that the previous model no longer exists',
            data=model_path.exists())
        self.assertFalse(model_path.exists())
        # Ensure that the freshly created model wasn't deleted.
        tagger_orm.refresh_from_db()
        self.assertNotEqual(tagger_orm.model.path, str(model_path))

        # delete feedback
        feedback_delete_url = f'{self.url}{tagger_id}/feedback/'
        response = self.client.delete(feedback_delete_url)
        print_output('test_tag_doc_delete_feedback:response.data',
                     response.data)
        # sleep for a sec to allow elastic to finish its business
        sleep(1)
        # list feedback again to make sure it's empty
        feedback_list_url = f'{self.url}{tagger_id}/feedback/'
        response = self.client.get(feedback_list_url)
        print_output('test_tag_doc_list_feedback_after_delete:response.data',
                     response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(len(response.data) == 0)
        # remove created index
        feedback_index_url = f'{self.project_url}/feedback/'
        response = self.client.delete(feedback_index_url)
        print_output('test_delete_feedback_index:response.data', response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue('success' in response.data)

        # register the retrained model files for cleanup
        self.add_cleanup_files(tagger_id)
Beispiel #17
0
 def tearDown(self) -> None:
     CRFExtractor.objects.all().delete()
     ec = ElasticCore()
     ec.delete_index(self.test_index_copy)