Example #1
    def setUp(self):
        # Owner of the project
        self.test_index = reindex_test_dataset(from_index=TEST_INDEX_ENTITY_EVALUATOR)
        self.user = create_test_user("EvaluatorOwner", "*****@*****.**", "pw")
        self.project = project_creation("EvaluatorTestProject", self.test_index, self.user)
        self.project.users.add(self.user)
        self.url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}/evaluators/"
        self.project_url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}"

        self.true_fact_name = "PER"
        self.pred_fact_name = "PER_CRF_30"

        self.true_fact_name_sent_index = "PER_SENT"
        self.pred_fact_name_sent_index = "PER_CRF_31_SENT"

        self.fact_name_no_spans = "PER_FN_REGEX_NO_SPANS"

        self.fact_name_different_doc_paths = "PER_DOUBLE"

        self.core_variables_url = f"{TEST_VERSION_PREFIX}/core_variables/5/"

        # TODO! Construct a test query
        self.fact_names_to_filter = [self.true_fact_name, self.pred_fact_name]
        self.test_query = Query()
        self.test_query.add_facts_filter(self.fact_names_to_filter, [], operator="must")
        self.test_query = self.test_query.__dict__()

        self.client.login(username="******", password="******")

        self.token_based_evaluator_id = None
        self.value_based_evaluator_id = None
        self.token_based_sent_index_evaluator_id = None
        self.value_based_sent_index_evaluator_id = None
Example #2
    def create_queries(self, fact_name, tags):
        """Creates queries for finding documents for each tag."""
        queries = []
        for tag in tags:
            query = Query()
            query.add_fact_filter(fact_name, tag)
            queries.append(query.query)
        return queries
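
A minimal usage sketch of create_queries, assuming `tagger` is an instance of the (unnamed) class above and that the resulting query dicts are fed to ElasticSearcher as in the later examples; the fact name, tag values and index name are placeholders, not part of the original example.

# Hypothetical usage; fact name, tags and index name are placeholders.
tags = ["650 kapital", "650 kuvand"]
queries = tagger.create_queries("TRUE_TAG", tags)
for tag, query in zip(tags, queries):
    # Each entry is a plain query dict that ElasticSearcher accepts directly.
    searcher = ElasticSearcher(indices=["test_index"], query=query)
    docs = searcher.search(size=10)
    print(tag, len(docs))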
Example #3
    def setUp(self):
        # Owner of the project
        self.test_index = reindex_test_dataset(from_index=TEST_INDEX_EVALUATOR)
        self.user = create_test_user("EvaluatorOwner", "*****@*****.**", "pw")
        self.project = project_creation("EvaluatorTestProject",
                                        self.test_index, self.user)
        self.project.users.add(self.user)
        self.url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}/evaluators/"
        self.project_url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}"

        self.multilabel_avg_functions = choices.MULTILABEL_AVG_FUNCTIONS
        self.binary_avg_functions = choices.BINARY_AVG_FUNCTIONS

        self.multilabel_evaluators = {
            avg: None
            for avg in self.multilabel_avg_functions
        }
        self.binary_evaluators = {
            avg: None
            for avg in self.binary_avg_functions
        }

        self.memory_optimized_multilabel_evaluators = {
            avg: None
            for avg in self.multilabel_avg_functions
        }
        self.memory_optimized_binary_evaluators = {
            avg: None
            for avg in self.binary_avg_functions
        }

        self.true_fact_name = "TRUE_TAG"
        self.pred_fact_name = "PREDICTED_TAG"

        self.true_fact_value = "650 kapital"
        self.pred_fact_value = "650 kuvand"

        self.core_variables_url = f"{TEST_VERSION_PREFIX}/core_variables/5/"

        # Construct a test query
        self.fact_names_to_filter = [self.true_fact_name, self.pred_fact_name]
        self.fact_values_to_filter = [
            "650 bioeetika", "650 rahvusbibliograafiad"
        ]
        self.test_query = Query()
        self.test_query.add_facts_filter(self.fact_names_to_filter,
                                         self.fact_values_to_filter,
                                         operator="must")
        self.test_query = self.test_query.__dict__()

        self.client.login(username="******", password="******")
Example #4
    def _initialize_es(self, project_pk, text_processor, callback_progress,
                       prediction_to_match):
        # create es doc
        es_doc = ElasticDocument(self.feedback_index)
        # if there is no model object, return None for search and query
        if not self.model_object:
            return es_doc, None, None
        # create matching query
        query = Query()
        query.add_string_filter(query_string=self.model_object.MODEL_TYPE,
                                fields=["model_type"])
        if self.model_object:
            query.add_string_filter(query_string=str(self.model_object.pk),
                                    fields=["model_id"])
        if prediction_to_match:
            query.add_string_filter(query_string=prediction_to_match,
                                    fields=["correct_result"])
        # if the index does not exist, don't create a searcher object
        if not self.check_index_exists():
            return es_doc, None, query.query
        # create es search
        es_search = ElasticSearcher(indices=self.feedback_index,
                                    query=query.query,
                                    text_processor=text_processor,
                                    output=ElasticSearcher.OUT_DOC_WITH_ID,
                                    callback_progress=callback_progress)
        # return objects
        return es_doc, es_search, query.query
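
A sketch of how a caller might consume the (document handler, searcher, query) triple returned by _initialize_es; `feedback` is an assumed instance of the class that defines the method, and all argument values below are placeholders.

# Hypothetical caller; `feedback` and the argument values are placeholders.
es_doc, es_search, query = feedback._initialize_es(
    project_pk=1,
    text_processor=None,
    callback_progress=None,
    prediction_to_match="true",
)
if es_search is not None:
    # The searcher was built with OUT_DOC_WITH_ID output, so hits carry their ids.
    docs = es_search.search(size=10)
    print(f"Matched {len(docs)} feedback documents for query: {query}")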
Example #5
def get_tag_candidates(tagger_group_id: int,
                       text: str,
                       ignore_tags: List[dict] = [],
                       n_similar_docs: int = 10,
                       max_candidates: int = 10):
    """
    Finds frequent tags from documents similar to the input document.
    Returns an empty list if the hybrid option is false.
    """
    hybrid_tagger_object = TaggerGroup.objects.get(pk=tagger_group_id)
    field_paths = json.loads(hybrid_tagger_object.taggers.first().fields)
    indices = hybrid_tagger_object.get_indices()
    logging.getLogger(INFO_LOGGER).info(
        f"[Get Tag Candidates] Selecting from following indices: {indices}.")
    ignore_tags = {tag["tag"]: True for tag in ignore_tags}
    # create query
    query = Query()
    query.add_mlt(field_paths, text)
    # create Searcher object for MLT
    es_s = ElasticSearcher(indices=indices, query=query.query)
    logging.getLogger(INFO_LOGGER).info(
        f"[Get Tag Candidates] Trying to retrieve {n_similar_docs} documents from Elastic..."
    )
    docs = es_s.search(size=n_similar_docs)
    logging.getLogger(INFO_LOGGER).info(
        f"[Get Tag Candidates] Successfully retrieved {len(docs)} documents from Elastic."
    )
    # dict for tag candidates from elastic
    tag_candidates = {}
    # retrieve tags from elastic response
    for doc in docs:
        if "texta_facts" in doc:
            for fact in doc["texta_facts"]:
                if fact["fact"] == hybrid_tagger_object.fact_name:
                    fact_val = fact["str_val"]
                    if fact_val not in ignore_tags:
                        if fact_val not in tag_candidates:
                            tag_candidates[fact_val] = 0
                        tag_candidates[fact_val] += 1
    # sort and limit candidates
    tag_candidates = [
        item[0] for item in sorted(
            tag_candidates.items(), key=lambda k: k[1], reverse=True)
    ][:max_candidates]
    logging.getLogger(INFO_LOGGER).info(
        f"[Get Tag Candidates] Retrieved {len(tag_candidates)} tag candidates."
    )
    return tag_candidates
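
An invocation sketch for get_tag_candidates; the tagger group id, input text and ignored tag entries below are placeholders (the ignore_tags items are dicts with a "tag" key, matching the lookup above).

# Hypothetical call; id, text and ignored tags are placeholders.
candidates = get_tag_candidates(
    tagger_group_id=1,
    text="Some input document to find similar documents for.",
    ignore_tags=[{"tag": "already_assigned_tag"}],
    n_similar_docs=10,
    max_candidates=5,
)
print(candidates)  # up to 5 most frequent tag values from similar documents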
Example #6
    def post(self, request, project_pk: int):
        """Simplified search interface for making Elasticsearch queries."""
        serializer = ProjectSimplifiedSearchSerializer(data=request.data)
        if not serializer.is_valid():
            raise SerializerNotValid(detail=serializer.errors)

        project_object = get_object_or_404(Project, pk=project_pk)
        self.check_object_permissions(request, project_object)
        project_indices = list(project_object.get_indices())
        project_fields = project_object.get_elastic_fields(path_list=True)
        # test if indices exist
        if not project_indices:
            raise ProjectValidationFailed(detail="Project has no indices")
        # test if indices are valid
        if serializer.validated_data['match_indices']:
            if not set(serializer.validated_data['match_indices']).issubset(set(project_indices)):
                raise ProjectValidationFailed(detail=f"Index names are not valid for this project. allowed values are: {project_indices}")
        # test if fields are valid
        if serializer.validated_data['match_fields']:
            if not set(serializer.validated_data['match_fields']).issubset(set(project_fields)):
                raise ProjectValidationFailed(detail=f"Fields names are not valid for this project. allowed values are: {project_fields}")

        es = ElasticSearcher(indices=project_indices, output=ElasticSearcher.OUT_DOC)
        q = Query(operator=serializer.validated_data['operator'])
        # if input is string, convert to list
        # if unknown format, return error
        match_text = serializer.validated_data['match_text']
        if isinstance(match_text, list):
            match_texts = [str(item) for item in match_text if item]
        elif isinstance(match_text, str):
            match_texts = [match_text]
        else:
            return Response({'error': f'match text is in unknown format: {match_text}'}, status=status.HTTP_400_BAD_REQUEST)
        # add query filters
        for item in match_texts:
            q.add_string_filter(item, match_type=serializer.validated_data["match_type"])
        # update query
        es.update_query(q.query)
        # retrieve results
        results = es.search(size=serializer.validated_data["size"])
        return Response(results, status=status.HTTP_200_OK)
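
For reference, a payload a client might send to this simplified search view; the field names come from the serializer usage above, while the concrete values (and the commented-out request call) are assumptions.

# Hypothetical request body; values are placeholders.
payload = {
    "match_text": ["bank", "loan"],  # a string or a list of strings, as handled above
    "match_type": "phrase",          # assumed choice; forwarded to Query.add_string_filter
    "match_indices": [],             # if non-empty, validated against the project's indices
    "match_fields": [],              # if non-empty, validated against the project's fields
    "operator": "must",              # forwarded to Query(operator=...)
    "size": 10,
}
# response = client.post(search_url, payload, format="json")  # search_url is assumed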
Example #7
class EntityEvaluatorObjectViewTests(APITransactionTestCase):
    def setUp(self):
        # Owner of the project
        self.test_index = reindex_test_dataset(from_index=TEST_INDEX_ENTITY_EVALUATOR)
        self.user = create_test_user("EvaluatorOwner", "*****@*****.**", "pw")
        self.project = project_creation("EvaluatorTestProject", self.test_index, self.user)
        self.project.users.add(self.user)
        self.url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}/evaluators/"
        self.project_url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}"

        self.true_fact_name = "PER"
        self.pred_fact_name = "PER_CRF_30"

        self.true_fact_name_sent_index = "PER_SENT"
        self.pred_fact_name_sent_index = "PER_CRF_31_SENT"

        self.fact_name_no_spans = "PER_FN_REGEX_NO_SPANS"

        self.fact_name_different_doc_paths = "PER_DOUBLE"

        self.core_variables_url = f"{TEST_VERSION_PREFIX}/core_variables/5/"

        # TODO! Construct a test query
        self.fact_names_to_filter = [self.true_fact_name, self.pred_fact_name]
        self.test_query = Query()
        self.test_query.add_facts_filter(self.fact_names_to_filter, [], operator="must")
        self.test_query = self.test_query.__dict__()

        self.client.login(username="******", password="******")

        self.token_based_evaluator_id = None
        self.value_based_evaluator_id = None
        self.token_based_sent_index_evaluator_id = None
        self.value_based_sent_index_evaluator_id = None


    def tearDown(self) -> None:
        from texta_elastic.core import ElasticCore
        ElasticCore().delete_index(index=self.test_index, ignore=[400, 404])


    def test(self):

        self.run_test_invalid_fact_name()
        self.run_test_invalid_fact_without_spans()
        self.run_test_invalid_doc_path()
        self.run_test_invalid_facts_have_different_doc_paths()
        self.run_test_invalid_fact_has_multiple_paths_field_name_unspecified()
        self.run_test_entity_evaluation_token_based()
        self.run_test_entity_evaluation_token_based_sent_index()
        self.run_test_entity_evaluation_value_based()
        self.run_test_entity_evaluation_value_based_sent_index()
        self.run_test_individual_results_view_entity(self.token_based_evaluator_id)
        self.run_test_filtered_average_view_entity(self.token_based_evaluator_id)
        self.run_test_misclassified_examples_get(self.token_based_evaluator_id)
        self.run_test_misclassified_examples_get(self.value_based_evaluator_id)
        self.run_test_misclassified_examples_get(self.token_based_sent_index_evaluator_id)
        self.run_test_misclassified_examples_get(self.value_based_sent_index_evaluator_id)
        self.run_test_misclassified_examples_post(self.token_based_evaluator_id)
        self.run_test_misclassified_examples_post(self.value_based_evaluator_id)
        self.run_test_misclassified_examples_post(self.token_based_sent_index_evaluator_id)
        self.run_test_misclassified_examples_post(self.value_based_sent_index_evaluator_id)
        self.run_test_entity_evaluation_with_query()
        self.run_export_import(self.token_based_evaluator_id)
        self.run_reevaluate(self.token_based_evaluator_id)
        self.run_delete(self.token_based_evaluator_id)
        self.run_patch(self.value_based_evaluator_id)


    def add_cleanup_files(self, evaluator_id: int):
        try:
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        except EvaluatorObject.DoesNotExist:
            return
        if not TEST_KEEP_PLOT_FILES:
            self.addCleanup(remove_file, evaluator_object.plot.path)


    def run_test_invalid_fact_name(self):
        """
        Check if evaluator endpoint throws an error if one of the
        selected fact names is not present in the selected indices.
        """
        invalid_payloads = [
            {
                "true_fact": self.true_fact_name,
                "predicted_fact": "INVALID_FACT_NAME"
            },
            {
                "true_fact": "INVALID_FACT_NAME",
                "predicted_fact": self.pred_fact_name
            }
        ]
        main_payload = {
            "description": "Test invalid fact name",
            "indices": [{"name": self.test_index}],
            "evaluation_type": "entity"
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("entity_evaluator:run_test_invalid_fact_name:response.data", response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)


    def run_test_invalid_fact_without_spans(self):
        """
        Check if evaluator endpoint throws an error if one of the
        selected fact names has only zero-valued spans.
        """
        invalid_payloads = [
            {
                "true_fact": self.true_fact_name,
                "predicted_fact": self.fact_name_no_spans
            },
            {
                "true_fact": self.fact_name_no_spans,
                "predicted_fact": self.pred_fact_name
            }
        ]
        main_payload = {
            "description": "Test invalid fact without spans",
            "indices": [{"name": self.test_index}],
            "evaluation_type": "entity"
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("entity_evaluator:run_test_invalid_fact_without_spans:response.data", response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)


    def run_test_invalid_doc_path(self):
        """
        Check if evaluator endpoint throws an error if the
        selected doc_path is invalid.
        """
        invalid_payloads = [
            {
                "true_fact": self.true_fact_name,
                "predicted_fact": self.pred_fact_name
            }
        ]
        main_payload = {
            "description": "Test invalid doc_path (field)",
            "indices": [{"name": self.test_index}],
            "evaluation_type": "entity",
            "field": "brr"
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("entity_evaluator:run_test_invalid_doc_path:response.data", response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)


    def run_test_invalid_facts_have_different_doc_paths(self):
        """
        Check if evaluator endpoint throws an error if the
        selected facts have different doc paths.
        """
        invalid_payloads = [
            {
                "true_fact": self.true_fact_name,
                "predicted_fact": self.pred_fact_name_sent_index
            },
            {
                "true_fact": self.true_fact_name_sent_index,
                "predicted_fact": self.pred_fact_name
            }
        ]
        main_payload = {
            "description": "Test invalid: facts have different doc paths (fields)",
            "indices": [{"name": self.test_index}],
            "evaluation_type": "entity"
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("entity_evaluator:run_test_invalid_facts_have_different_doc_paths:response.data", response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)


    def run_test_invalid_fact_has_multiple_paths_field_name_unspecified(self):
        """
        Check if evaluator endpoint throws an error if one of the
        selected fact names is related to more than one doc path,
        but the user hasn't specified the field.
        """
        invalid_payloads = [
            {
                "true_fact": self.true_fact_name,
                "predicted_fact": self.fact_name_different_doc_paths
            },
            {
                "true_fact": self.fact_name_different_doc_paths,
                "predicted_fact": self.pred_fact_name
            }
        ]
        main_payload = {
            "description": "Test invalid fact without spans",
            "indices": [{"name": self.test_index}],
            "evaluation_type": "entity"
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("entity_evaluator:run_test_invalid_fact_has_multiple_paths_field_name_unspecified:response.data", response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)


    def run_test_individual_results_view_entity(self, evaluator_id: int):
        """ Test individual_results endpoint for entity evaluators."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        evaluation_type = evaluator_object.evaluation_type

        url = f"{self.url}{evaluator_id}/individual_results/"

        default_payload = {}

        response = self.client.post(url, default_payload, format="json")
        print_output(f"entity_evaluator:run_test_individual_results_view_binary:{evaluation_type}:default_payload:response.data:", response.data)

        # The usage of the endpoint is not available for binary evaluators
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

        self.add_cleanup_files(evaluator_id)


    def run_test_filtered_average_view_entity(self, evaluator_id: int):
        """ Test filtered_average endpoint for binary evaluators. """

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        evaluation_type = evaluator_object.evaluation_type

        url = f"{self.url}{evaluator_id}/filtered_average/"

        default_payload = {}

        response = self.client.post(url, default_payload, format="json")
        print_output(f"entity_evaluator:run_test_filtered_average_view_entity:{evaluation_type}:default_payload:response.data:", response.data)

        # The filtered_average endpoint is not available for entity evaluators
        self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

        self.add_cleanup_files(evaluator_id)


    def run_test_entity_evaluation_token_based(self):
        """ Test token-based entity evaluation. """

        payload = {
            "description": "Test token-based entity evaluation",
            "indices": [{"name": self.test_index}],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 50,
            "add_misclassified_examples": True,
            "token_based": True,
            "evaluation_type": "entity"

        }

        expected_scores = {
            "accuracy": 0.99,
            "precision": 0.84,
            "recall": 0.85,
            "f1_score": 0.84
        }


        response = self.client.post(self.url, payload, format="json")
        print_output(f"entity_evaluator:run_test_entity_evaluation_token_based:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output("entity_evaluator:run_test_entity_evaluation_token_based: waiting for evaluation task to finish, current status:", evaluator_object.task.status)
            sleep(1)
            # Refresh the task from the database so the loop can observe status changes
            evaluator_object.task.refresh_from_db()

        evaluator_json = evaluator_object.to_json()
        evaluator_json.pop("misclassified_examples")

        print_output(f"entity_evaluator:run_test_entity_evaluation_token_based:evaluator_object.json:", evaluator_json)

        for metric in choices.METRICS:
            self.assertEqual(round(evaluator_json[metric], 2), expected_scores[metric])

        self.assertEqual(evaluator_object.n_total_classes, 877)
        self.assertEqual(evaluator_object.n_true_classes, 757)
        self.assertEqual(evaluator_object.n_predicted_classes, 760)

        cm = np.array(json.loads(evaluator_object.confusion_matrix))
        cm_size = np.shape(cm)

        self.assertEqual(2, cm_size[0])
        self.assertEqual(2, cm_size[1])

        self.assertEqual(evaluator_object.document_count, 100)
        self.assertEqual(evaluator_object.add_individual_results, False)
        self.assertEqual(evaluator_object.scores_imprecise, False)
        self.assertEqual(evaluator_object.token_based, True)
        self.assertEqual(evaluator_object.evaluation_type, "entity")

        self.token_based_evaluator_id = evaluator_id

        self.add_cleanup_files(evaluator_id)



    def run_test_entity_evaluation_token_based_sent_index(self):
        """ Test token-based entity evaluation with sentence-level spans. """

        payload = {
            "description": "Test token-based entity evaluation with sentence-level spans",
            "indices": [{"name": self.test_index}],
            "true_fact": self.true_fact_name_sent_index,
            "predicted_fact": self.pred_fact_name_sent_index,
            "scroll_size": 50,
            "add_misclassified_examples": True,
            "token_based": True,
            "evaluation_type": "entity"

        }

        expected_scores = {
            "accuracy": 1.0,
            "precision": 0.93,
            "recall": 0.90,
            "f1_score": 0.92
        }


        response = self.client.post(self.url, payload, format="json")
        print_output(f"entity_evaluator:run_test_entity_evaluation_token_based:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output("entity_evaluator:run_test_entity_evaluation_token_based_sent_index: waiting for evaluation task to finish, current status:", evaluator_object.task.status)
            sleep(1)
            # Refresh the task from the database so the loop can observe status changes
            evaluator_object.task.refresh_from_db()

        evaluator_json = evaluator_object.to_json()
        evaluator_json.pop("misclassified_examples")

        print_output(f"entity_evaluator:run_test_entity_evaluation_token_based_sent_index:evaluator_object.json:", evaluator_json)

        for metric in choices.METRICS:
            self.assertEqual(round(evaluator_json[metric], 2), expected_scores[metric])

        self.assertEqual(evaluator_object.n_total_classes, 802)
        self.assertEqual(evaluator_object.n_true_classes, 754)
        self.assertEqual(evaluator_object.n_predicted_classes, 726)

        cm = np.array(json.loads(evaluator_object.confusion_matrix))
        cm_size = np.shape(cm)

        self.assertEqual(2, cm_size[0])
        self.assertEqual(2, cm_size[1])

        self.assertEqual(evaluator_object.document_count, 100)
        self.assertEqual(evaluator_object.add_individual_results, False)
        self.assertEqual(evaluator_object.scores_imprecise, False)
        self.assertEqual(evaluator_object.token_based, True)
        self.assertEqual(evaluator_object.evaluation_type, "entity")

        self.token_based_sent_index_evaluator_id = evaluator_id

        self.add_cleanup_files(evaluator_id)


    def run_test_entity_evaluation_value_based(self):
        """ Test value-based entity evaluation. """

        payload = {
            "description": "Test value-based entity evaluation",
            "indices": [{"name": self.test_index}],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 50,
            "add_misclassified_examples": True,
            "token_based": False,
            "evaluation_type": "entity"

        }

        expected_scores = {
            "accuracy": 0.61,
            "precision": 0.68,
            "recall": 0.80,
            "f1_score": 0.73
        }


        response = self.client.post(self.url, payload, format="json")
        print_output(f"entity_evaluator:run_test_entity_evaluation_value_based:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output("entity_evaluator:run_test_entity_evaluation_value_based: waiting for evaluation task to finish, current status:", evaluator_object.task.status)
            sleep(1)
            # Refresh the task from the database so the loop can observe status changes
            evaluator_object.task.refresh_from_db()

        evaluator_json = evaluator_object.to_json()
        evaluator_json.pop("misclassified_examples")

        print_output(f"entity_evaluator:run_test_entity_evaluation_value_based:evaluator_object.json:", evaluator_json)

        for metric in choices.METRICS:
            self.assertEqual(round(evaluator_json[metric], 2), expected_scores[metric])

        self.assertEqual(evaluator_object.n_total_classes, 600)
        self.assertEqual(evaluator_object.n_true_classes, 437)
        self.assertEqual(evaluator_object.n_predicted_classes, 511)

        cm = np.array(json.loads(evaluator_object.confusion_matrix))
        cm_size = np.shape(cm)

        self.assertEqual(2, cm_size[0])
        self.assertEqual(2, cm_size[1])

        self.assertEqual(evaluator_object.document_count, 100)
        self.assertEqual(evaluator_object.add_individual_results, False)
        self.assertEqual(evaluator_object.scores_imprecise, False)
        self.assertEqual(evaluator_object.token_based, False)
        self.assertEqual(evaluator_object.evaluation_type, "entity")

        self.value_based_evaluator_id = evaluator_id

        self.add_cleanup_files(evaluator_id)


    def run_test_entity_evaluation_value_based_sent_index(self):
        """ Test value-based entity evaluation with sentence-level spans. """

        payload = {
            "description": "Test value-based entity evaluation with sentence-level spans",
            "indices": [{"name": self.test_index}],
            "true_fact": self.true_fact_name_sent_index,
            "predicted_fact": self.pred_fact_name_sent_index,
            "scroll_size": 50,
            "add_misclassified_examples": True,
            "token_based": False,
            "evaluation_type": "entity"

        }

        expected_scores = {
            "accuracy": 0.95,
            "precision": 0.92,
            "recall": 0.84,
            "f1_score": 0.88
        }


        response = self.client.post(self.url, payload, format="json")
        print_output(f"entity_evaluator:run_test_entity_evaluation_value_based_sent_index:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output("entity_evaluator:run_test_entity_evaluation_value_based_sent_index: waiting for evaluation task to finish, current status:", evaluator_object.task.status)
            sleep(1)
            # Refresh the task from the database so the loop can observe status changes
            evaluator_object.task.refresh_from_db()

        evaluator_json = evaluator_object.to_json()
        evaluator_json.pop("misclassified_examples")

        print_output(f"entity_evaluator:run_test_entity_evaluation_value_based_sent_index:evaluator_object.json:", evaluator_json)

        for metric in choices.METRICS:
            self.assertEqual(round(evaluator_json[metric], 2), expected_scores[metric])

        self.assertEqual(evaluator_object.n_total_classes, 481)
        self.assertEqual(evaluator_object.n_true_classes, 447)
        self.assertEqual(evaluator_object.n_predicted_classes, 410)

        cm = np.array(json.loads(evaluator_object.confusion_matrix))
        cm_size = np.shape(cm)

        self.assertEqual(2, cm_size[0])
        self.assertEqual(2, cm_size[1])

        self.assertEqual(evaluator_object.document_count, 100)
        self.assertEqual(evaluator_object.add_individual_results, False)
        self.assertEqual(evaluator_object.scores_imprecise, False)
        self.assertEqual(evaluator_object.token_based, False)
        self.assertEqual(evaluator_object.evaluation_type, "entity")


        self.value_based_sent_index_evaluator_id = evaluator_id

        self.add_cleanup_files(evaluator_id)


    def run_test_misclassified_examples_get(self, evaluator_id: int):
        """ Test misclassified_examples endpoint with GET request. """


        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        token_based = evaluator_object.token_based

        url = f"{self.url}{evaluator_id}/misclassified_examples/"

        response = self.client.get(url, format="json")
        print_output(f"entity_evaluator:run_test_misclassified_examples_view_get:token_based:{token_based}:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        keys = ["substrings", "superstrings", "partial", "false_negatives", "false_positives"]

        for key in keys:
            self.assertTrue(key in response.data)
            self.assertTrue(isinstance(response.data[key], list))


        value_types_dict = ["substrings", "superstrings", "partial"]
        value_types_str = ["false_negatives", "false_positives"]

        for key in list(response.data.keys()):
            if response.data[key]:
                if key in value_types_dict:
                    self.assertTrue("true" in response.data[key][0]["value"])
                    self.assertTrue("pred" in response.data[key][0]["value"])
                elif key in value_types_str:
                    self.assertTrue(isinstance(response.data[key][0]["value"], str))
                self.assertTrue("count" in response.data[key][0])



    def run_test_misclassified_examples_post(self, evaluator_id: int):
        """ Test misclassified examples endpoint with POST request."""
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        token_based = evaluator_object.token_based


        url = f"{self.url}{evaluator_id}/misclassified_examples/"

        # Test param `min_count`
        payload_min_count = {
            "min_count": 2
        }

        response = self.client.post(url, payload_min_count, format="json")
        print_output(f"entity_evaluator:run_test_misclassified_examples_view_min_count_post:token_based:{token_based}:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        keys = ["substrings", "superstrings", "partial", "false_negatives", "false_positives"]

        for key in keys:
            self.assertTrue(key in response.data)
            self.assertTrue(isinstance(response.data[key], dict))
            nested_keys = ["values", "total_unique_count", "filtered_unique_count"]
            for nested_key in nested_keys:
                self.assertTrue(nested_key in response.data[key])

            # Check that the number of filtered values is less than or equal to the number of total values
            self.assertTrue(response.data[key]["total_unique_count"] >= response.data[key]["filtered_unique_count"])

            # Check that no value with a count below min_count is present in the results
            for value in response.data[key]["values"]:
                self.assertTrue(value["count"] >= payload_min_count["min_count"])



        # Test param `max_count`
        payload_max_count = {
            "max_count": 2
        }

        response = self.client.post(url, payload_max_count, format="json")
        print_output(f"entity_evaluator:run_test_misclassified_examples_view_max_count_post:token_based:{token_based}:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        keys = ["substrings", "superstrings", "partial", "false_negatives", "false_positives"]

        for key in keys:
            self.assertTrue(key in response.data)
            self.assertTrue(isinstance(response.data[key], dict))
            nested_keys = ["values", "total_unique_count", "filtered_unique_count"]
            for nested_key in nested_keys:
                self.assertTrue(nested_key in response.data[key])

            # Check that the number of filtered values is less than or equal to the number of total values
            self.assertTrue(response.data[key]["total_unique_count"] >= response.data[key]["filtered_unique_count"])

            # Check that no value with a count above max_count is present in the results
            for value in response.data[key]["values"]:
                self.assertTrue(value["count"] <= payload_max_count["max_count"])


        # Test param `top_n`
        payload_top_n = {
            "top_n": 5
        }

        response = self.client.post(url, payload_top_n, format="json")
        print_output(f"entity_evaluator:run_test_misclassified_examples_view_top_n_post:token_based:{token_based}:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(isinstance(response.data, dict))

        keys = ["substrings", "superstrings", "partial", "false_negatives", "false_positives"]

        for key in keys:
            self.assertTrue(key in response.data)
            self.assertTrue(isinstance(response.data[key], dict))
            nested_keys = ["values", "total_unique_count", "filtered_unique_count"]
            for nested_key in nested_keys:
                self.assertTrue(nested_key in response.data[key])

            # Check that the number of filtered values is less than or equal to the number of total values
            self.assertTrue(response.data[key]["total_unique_count"] >= response.data[key]["filtered_unique_count"])

            # Check that at most top n values are present for each key
            self.assertTrue(response.data[key]["filtered_unique_count"] <= payload_top_n["top_n"])


    def run_test_entity_evaluation_with_query(self):
        """ Test if running the entity evaluation with query works. """

        payload = {
            "description": "Test evaluation with query",
            "indices": [{"name": self.test_index}],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 50,
            "add_misclassified_examples": False,
            "query": self.test_query,
            "evaluation_type": "entity"
        }

        response = self.client.post(self.url, payload, format="json")
        print_output(f"entity_evaluator:run_test_entity_evaluation_with_query:response.data", response.data)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output("entity_evaluator:run_test_evaluation_with_query: waiting for evaluation task to finish, current status:", evaluator_object.task.status)
            sleep(1)
            # Refresh the task from the database so the loop can observe status changes
            evaluator_object.task.refresh_from_db()
            if evaluator_object.task.status == Task.STATUS_FAILED:
                print_output("entity_evaluator:run_test_evaluation_with_query: status = failed: error:", evaluator_object.task.errors)
            self.assertFalse(evaluator_object.task.status == Task.STATUS_FAILED)

        # Check if the document count is in sync with the query
        self.assertEqual(evaluator_object.document_count, 68)
        self.add_cleanup_files(evaluator_id)


    def run_export_import(self, evaluator_id: int):
        """Tests endpoint for model export and import."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

        eval_type = evaluator_object.evaluation_type

        # retrieve model zip
        url = f"{self.url}{evaluator_id}/export_model/"
        response = self.client.get(url)

        # Post model zip
        import_url = f"{self.url}import_model/"
        response = self.client.post(import_url, data={"file": BytesIO(response.content)})

        print_output(f"entity_evaluator:run_export_import:evaluation_type:{eval_type}:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        imported_evaluator_object = EvaluatorObject.objects.get(pk=response.data["id"])
        imported_evaluator_id = response.data["id"]

        # Check if the models and plot files exist.
        resources = imported_evaluator_object.get_resource_paths()
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertTrue(file.exists())

        evaluator_object_json = evaluator_object.to_json()
        imported_evaluator_object_json = imported_evaluator_object.to_json()

        # Check if scores in original and imported model are the same
        for metric in choices.METRICS:
            self.assertEqual(evaluator_object_json[metric], imported_evaluator_object_json[metric])

        self.add_cleanup_files(evaluator_id)
        self.add_cleanup_files(imported_evaluator_id)


    def run_reevaluate(self, evaluator_id: int):
        """Tests endpoint for re-evaluation."""
        url = f"{self.url}{evaluator_id}/reevaluate/"
        payload = {}
        response = self.client.post(url, payload, format="json")
        print_output(f"entity_evaluator:run_reevaluate:response.data", response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)


    def run_delete(self, evaluator_id: int):
        """Test deleting evaluator and its resources."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        resources = evaluator_object.get_resource_paths()

        url = f"{self.url}{evaluator_id}/"
        response = self.client.delete(url, format="json")
        print_output(f"entity_evaluator:run_delete:delete:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

        response = self.client.get(url, format="json")
        print_output(f"entity_evaluator:run_delete:get:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

        # Check if additional files get deleted
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertFalse(file.exists())


    def run_patch(self, evaluator_id: int):
        """Test updating description."""
        url = f"{self.url}{evaluator_id}/"

        payload = {"description": "New description"}

        response = self.client.patch(url, payload, format="json")
        print_output(f"entity_evaluator:run_patch:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = self.client.get(url, format="json")
        self.assertEqual(response.data["description"], "New description")

        self.add_cleanup_files(evaluator_id)
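
The tests above double as documentation of the evaluator API. Below is a hedged stand-alone sketch of creating a token-based entity evaluator with the same payload shape, using the requests library; the host, API prefix, project id, credentials and index name are placeholders, and the fact names are the ones used in the tests.

# Hypothetical client-side sketch; URL, credentials and index name are placeholders.
import requests

payload = {
    "description": "Token-based entity evaluation",
    "indices": [{"name": "my_test_index"}],
    "true_fact": "PER",
    "predicted_fact": "PER_CRF_30",
    "token_based": True,
    "add_misclassified_examples": True,
    "evaluation_type": "entity",
}
response = requests.post(
    "http://localhost:8000/api/v2/projects/1/evaluators/",  # assumed host and API prefix
    json=payload,
    auth=("username", "password"),
)
print(response.status_code, response.json().get("id"))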
Example #8
class BinaryAndMultilabelEvaluatorObjectViewTests(APITransactionTestCase):
    def setUp(self):
        # Owner of the project
        self.test_index = reindex_test_dataset(from_index=TEST_INDEX_EVALUATOR)
        self.user = create_test_user("EvaluatorOwner", "*****@*****.**", "pw")
        self.project = project_creation("EvaluatorTestProject",
                                        self.test_index, self.user)
        self.project.users.add(self.user)
        self.url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}/evaluators/"
        self.project_url = f"{TEST_VERSION_PREFIX}/projects/{self.project.id}"

        self.multilabel_avg_functions = choices.MULTILABEL_AVG_FUNCTIONS
        self.binary_avg_functions = choices.BINARY_AVG_FUNCTIONS

        self.multilabel_evaluators = {
            avg: None
            for avg in self.multilabel_avg_functions
        }
        self.binary_evaluators = {
            avg: None
            for avg in self.binary_avg_functions
        }

        self.memory_optimized_multilabel_evaluators = {
            avg: None
            for avg in self.multilabel_avg_functions
        }
        self.memory_optimized_binary_evaluators = {
            avg: None
            for avg in self.binary_avg_functions
        }

        self.true_fact_name = "TRUE_TAG"
        self.pred_fact_name = "PREDICTED_TAG"

        self.true_fact_value = "650 kapital"
        self.pred_fact_value = "650 kuvand"

        self.core_variables_url = f"{TEST_VERSION_PREFIX}/core_variables/5/"

        # Construct a test query
        self.fact_names_to_filter = [self.true_fact_name, self.pred_fact_name]
        self.fact_values_to_filter = [
            "650 bioeetika", "650 rahvusbibliograafiad"
        ]
        self.test_query = Query()
        self.test_query.add_facts_filter(self.fact_names_to_filter,
                                         self.fact_values_to_filter,
                                         operator="must")
        self.test_query = self.test_query.__dict__()

        self.client.login(username="******", password="******")

    def tearDown(self) -> None:
        from texta_elastic.core import ElasticCore
        ElasticCore().delete_index(index=self.test_index, ignore=[400, 404])

    def test(self):

        self.run_test_invalid_fact_name()
        self.run_test_invalid_fact_value()
        self.run_test_invalid_average_function()

        self.run_test_evaluation_with_query()

        self.run_test_binary_evaluation()
        self.run_test_multilabel_evaluation(add_individual_results=True)
        self.run_test_multilabel_evaluation(add_individual_results=False)

        self.run_test_multilabel_evaluation_with_scoring_after_each_scroll(
            add_individual_results=True)
        self.run_test_multilabel_evaluation_with_scoring_after_each_scroll(
            add_individual_results=False)

        self.run_test_individual_results_enabled(
            self.memory_optimized_multilabel_evaluators.values())
        self.run_test_individual_results_enabled(
            self.multilabel_evaluators.values())
        self.run_test_individual_results_disabled(
            self.binary_evaluators.values())

        self.run_test_individual_results_view_multilabel(
            self.multilabel_evaluators["macro"])
        self.run_test_individual_results_view_invalid_input_multilabel(
            self.multilabel_evaluators["macro"])
        self.run_test_individual_results_view_binary(
            self.binary_evaluators["macro"])

        self.run_test_filtered_average_view_multilabel_get(
            self.multilabel_evaluators["macro"])
        self.run_test_filtered_average_view_multilabel_post(
            self.multilabel_evaluators["macro"])
        self.run_test_filtered_average_view_binary(
            self.binary_evaluators["macro"])

        self.run_export_import(self.binary_evaluators["macro"])
        self.run_export_import(self.multilabel_evaluators["macro"])

        self.run_patch(self.binary_evaluators["macro"])
        self.run_reevaluate(self.binary_evaluators["macro"])

        self.run_delete(self.binary_evaluators["macro"])

    def add_cleanup_files(self, evaluator_id: int):
        try:
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        except EvaluatorObject.DoesNotExist:
            return
        if not TEST_KEEP_PLOT_FILES:
            self.addCleanup(remove_file, evaluator_object.plot.path)

    def run_patch(self, evaluator_id: int):
        """Test updating description."""
        url = f"{self.url}{evaluator_id}/"

        payload = {"description": "New description"}

        response = self.client.patch(url, payload, format="json")
        print_output(f"evaluator:run_patch:response.data:", response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = self.client.get(url, format="json")
        self.assertEqual(response.data["description"], "New description")

        self.add_cleanup_files(evaluator_id)

    def run_test_invalid_fact_name(self):
        """
        Check if evaluator endpoint throws an error if one of the
        selected fact names is not present in the selected indices.
        """
        invalid_payloads = [{
            "true_fact": self.true_fact_name,
            "predicted_fact": "INVALID_FACT_NAME"
        }, {
            "true_fact": "INVALID_FACT_NAME",
            "predicted_fact": self.pred_fact_name
        }]
        main_payload = {
            "description": "Test invalid fact name",
            "indices": [{
                "name": self.test_index
            }]
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("evaluator:run_test_invalid_fact_name:response.data",
                         response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_test_invalid_fact_value(self):
        """
        Check if evaluator endpoint throws an error if one of the
        selected fact values is not present for the selected fact name.
        """
        invalid_payloads = [{
            "true_fact_value": self.true_fact_value,
            "predicted_fact_value": "INVALID_FACT_NAME"
        }, {
            "true_fact_value": "INVALID_FACT_NAME",
            "predicted_fact_value": self.pred_fact_value
        }]
        main_payload = {
            "description": "Test invalid fact name",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name
        }
        for invalid_payload in invalid_payloads:
            payload = {**main_payload, **invalid_payload}
            response = self.client.post(self.url, payload, format="json")
            print_output("evaluator:run_test_invalid_fact_value:response.data",
                         response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_test_invalid_average_function(self):
        """
        Check if evaluator endpoint throws an error if binary average
        function is chosen for multilabel evaluation.
        """

        main_payload = {
            "description": "Test invalid fact name",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
        }

        invalid_binary_payload = {
            "true_fact_value": self.true_fact_value,
            "predicted_fact_value": self.pred_fact_value,
            "average_function": "samples"
        }

        invalid_multilabel_payload = {"average_function": "binary"}

        invalid_payloads = {
            "binary": invalid_binary_payload,
            "multilabel": invalid_multilabel_payload
        }

        for eval_type, invalid_payload in list(invalid_payloads.items()):
            payload = {**main_payload, **invalid_payload}

            response = self.client.post(self.url,
                                        payload,
                                        format="json")
            print_output(
                f"evaluator:run_test_invalid_average_function:evaluation_type:{eval_type}:response.data",
                response.data)
            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def run_test_individual_results_view_multilabel(self, evaluator_id: int):
        """ Test individual_results endpoint for multilabel evaluators."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        avg_function = evaluator_object.average_function

        url = f"{self.url}{evaluator_id}/individual_results/"

        default_payload = {}

        response = self.client.post(url, default_payload, format="json")
        print_output(
            f"evaluator:run_test_individual_results_view_multilabel:avg:{avg_function}:default_payload:response.data:",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        self.assertEqual(response.data["total"],
                         evaluator_object.n_total_classes)

        # Test filtering by count
        payload = {"min_count": 600, "max_count": 630}
        print_output(
            f"evaluator:run_test_individual_results_view_multilabel:avg:{avg_function}:restricted_count:payload:",
            payload)

        response = self.client.post(url, payload, format="json")
        print_output(
            f"evaluator:run_test_individual_results_view_multilabel:avg:{avg_function}:restricted_count:response.data:",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["total"], 3)

        # Test filtering by precision and accuracy
        payload = {
            "metric_restrictions": {
                "precision": {
                    "min_score": 0.57
                },
                "accuracy": {
                    "min_score": 0.84
                }
            }
        }
        print_output(
            f"evaluator:run_test_individual_results_view_multilabel:avg:{avg_function}:restricted_metrics:payload:",
            payload)

        response = self.client.post(url, payload, format="json")
        print_output(
            f"evaluator:run_test_individual_results_view_multilabel:avg:{avg_function}:restricted_metrics:response.data:",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data["total"] < 10)

        self.add_cleanup_files(evaluator_id)

    def run_test_individual_results_view_binary(self, evaluator_id: int):
        """ Test individual_results endpoint for binary evaluators. """

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        evaluation_type = evaluator_object.evaluation_type

        url = f"{self.url}{evaluator_id}/individual_results/"

        default_payload = {}

        response = self.client.post(url, default_payload, format="json")
        print_output(
            f"evaluator:run_test_individual_results_view_binary:avg:{evaluation_type}:default_payload:response.data:",
            response.data)

        # The usage of the endpoint is not available for binary evaluators
        self.assertEqual(response.status_code,
                         status.HTTP_405_METHOD_NOT_ALLOWED)

        self.add_cleanup_files(evaluator_id)

    def run_test_individual_results_view_invalid_input_multilabel(
            self, evaluator_id: int):
        """ Test individual_results endpoint for multilabel evaluators with invalid input. """
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        avg_function = evaluator_object.average_function

        url = f"{self.url}{evaluator_id}/individual_results/"

        invalid_payloads = [{
            "metric_restrictions": {
                "asd": {
                    "max_score": 0.5
                }
            }
        }, {
            "metric_restrictions": {
                "precision": 0
            }
        }, {
            "metric_restrictions": {
                "precision": {
                    "asd": 8
                }
            }
        }, {
            "metric_restrictions": {
                "precision": {
                    "min_score": 18
                }
            }
        }, {
            "metric_restrictions": ["asd"]
        }]

        for i, payload in enumerate(invalid_payloads):
            print_output(
                f"evaluator:run_test_individual_results_view_invalid_input_multilabel:avg:{avg_function}:payload:",
                payload)

            response = self.client.post(url, payload, format="json")
            print_output(
                f"evaluator:run_test_individual_results_view_invalid_input_multilabel:avg:{avg_function}:response.data:",
                response.data)

            self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

            self.add_cleanup_files(evaluator_id)

    def run_test_filtered_average_view_binary(self, evaluator_id: int):
        """ Test filtered_average endpoint for binary evaluators. """

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        avg_function = evaluator_object.average_function

        url = f"{self.url}{evaluator_id}/filtered_average/"

        default_payload = {}

        response = self.client.post(url, default_payload, format="json")
        print_output(
            f"evaluator:run_test_filtered_average_view_binary:avg:{avg_function}:default_payload:response.data:",
            response.data)

        # The usage of the endpoint is not available for binary evaluators
        self.assertEqual(response.status_code,
                         status.HTTP_405_METHOD_NOT_ALLOWED)

        self.add_cleanup_files(evaluator_id)

    def run_test_filtered_average_view_multilabel_get(self, evaluator_id: int):
        """ Test GET method of filtered_average endpoint for multilabel evaluators. """
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        avg_function = evaluator_object.average_function

        url = f"{self.url}{evaluator_id}/filtered_average/"

        response = self.client.get(url, format="json")
        print_output(
            f"evaluator:run_test_filtered_average_view_multilabel_get:avg:{avg_function}:response.data:",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["count"],
                         evaluator_object.n_total_classes)
        for metric in choices.METRICS:
            self.assertTrue(response.data[metric] > 0)

        self.add_cleanup_files(evaluator_id)

    def run_test_filtered_average_view_multilabel_post(self,
                                                       evaluator_id: int):
        """ Test POST method of filtered_average endpoint for multilabel evaluators. """
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        avg_function = evaluator_object.average_function

        url = f"{self.url}{evaluator_id}/filtered_average/"

        payload = {"min_count": 600}

        print_output(
            f"evaluator:run_test_filtered_average_view_multilabel_post:avg:{avg_function}:payload:",
            payload)

        response = self.client.post(url, payload, format="json")
        print_output(
            f"evaluator:run_test_filtered_average_view_multilabel_post:avg:{avg_function}:response.data:",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["count"], 4)

        for metric in choices.METRICS:
            self.assertTrue(response.data[metric] > 0)

        self.add_cleanup_files(evaluator_id)
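
    # --- Illustrative sketch (not part of the original test suite) ---
    # The assertions above only check that filtered_average returns a count of 4
    # and positive metric scores. Conceptually, filtering by "min_count" keeps
    # the per-class results whose support reaches the threshold and averages the
    # rest. A minimal sketch of that idea, assuming individual_results maps each
    # label to its metric scores plus a "count" field (the exact schema of the
    # stored results is an assumption here):
    @staticmethod
    def _sketch_filtered_average(individual_results: dict, min_count: int) -> dict:
        kept = {label: scores for label, scores in individual_results.items()
                if scores.get("count", 0) >= min_count}
        metrics = ["precision", "recall", "f1_score", "accuracy"]
        averaged = {
            metric: (sum(scores[metric] for scores in kept.values()) / len(kept)) if kept else 0.0
            for metric in metrics
        }
        averaged["count"] = len(kept)
        return averaged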

    def run_test_individual_results_enabled(self, evaluator_ids: List[int]):
        """
        Test that the individual results stored in multilabel evaluators
        contain the correct information.
        """

        for evaluator_id in evaluator_ids:

            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

            individual_results = json.loads(
                evaluator_object.individual_results)

            evaluation_type = evaluator_object.evaluation_type
            memory_optimized = evaluator_object.score_after_scroll
            avg_function = evaluator_object.average_function

            print_output(
                f"evaluator:run_test_individual_results_enabled:{evaluation_type}:{avg_function}:memory_optimized:{memory_optimized}:response.data",
                individual_results)

            # Check if individual results exist for all the classes
            self.assertEqual(evaluator_object.n_total_classes,
                             len(individual_results))

            for label, scores in list(individual_results.items()):
                for metric in choices.METRICS:
                    self.assertTrue(scores[metric] > 0)

                cm = np.array(scores["confusion_matrix"])
                cm_size = np.shape(cm)
                # Check if confusion matrix has non-zero values
                self.assertTrue(cm.any())
                # Check if the confusion matrix has the correct shape
                self.assertEqual(cm_size[0], 2)
                self.assertEqual(cm_size[1], 2)
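
    # --- Illustrative sketch (not part of the original test suite) ---
    # The (2, 2) shape asserted above is what one-vs-rest scoring produces for
    # each label. A minimal sketch with toy label sets, assuming scikit-learn is
    # available (the evaluator's own implementation may differ):
    @staticmethod
    def _sketch_per_label_confusion_matrices() -> dict:
        from sklearn.metrics import multilabel_confusion_matrix
        from sklearn.preprocessing import MultiLabelBinarizer

        y_true = [{"PER", "ORG"}, {"PER"}, {"LOC"}]
        y_pred = [{"PER"}, {"PER", "LOC"}, {"LOC"}]

        binarizer = MultiLabelBinarizer()
        y_true_bin = binarizer.fit_transform(y_true)
        y_pred_bin = binarizer.transform(y_pred)

        # One 2x2 matrix per label, laid out as [[TN, FP], [FN, TP]]
        matrices = multilabel_confusion_matrix(y_true_bin, y_pred_bin)
        return dict(zip(binarizer.classes_, matrices))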

    def run_test_individual_results_disabled(self, evaluator_ids: List[int]):
        """
        Test that individual results are not stored in the evaluator
        when add_individual_results is set to False.
        """

        for evaluator_id in evaluator_ids:
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

            individual_results = json.loads(
                evaluator_object.individual_results)

            evaluation_type = evaluator_object.evaluation_type
            memory_optimized = evaluator_object.score_after_scroll
            avg_function = evaluator_object.average_function

            print_output(
                f"evaluator:run_test_individual_results_disabled:type:{evaluation_type}:avg:{avg_function}:memory_optimized:{memory_optimized}:response.data",
                individual_results)

            # Check that the individual results are empty
            self.assertEqual(0, len(individual_results))

    def run_test_binary_evaluation(self):
        """ Test binary evaluation with averaging functions set in self.binary_avg_functions."""

        main_payload = {
            "description": "Test binary evaluation",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "true_fact_value": self.true_fact_value,
            "predicted_fact_value": self.pred_fact_value,
            "scroll_size": 500,
            "add_individual_results": False
        }

        expected_scores = {
            "weighted": {
                "accuracy": 0.66,
                "precision": 0.68,
                "recall": 0.66,
                "f1_score": 0.67
            },
            "micro": {
                "accuracy": 0.66,
                "precision": 0.66,
                "recall": 0.66,
                "f1_score": 0.66
            },
            "macro": {
                "accuracy": 0.66,
                "precision": 0.49,
                "recall": 0.49,
                "f1_score": 0.49
            },
            "binary": {
                "accuracy": 0.66,
                "precision": 0.18,
                "recall": 0.21,
                "f1_score": 0.19
            }
        }

        for avg_function in self.binary_avg_functions:
            avg_function_payload = {"average_function": avg_function}
            payload = {**main_payload, **avg_function_payload}

            response = self.client.post(self.url, payload, format="json")
            print_output(
                f"evaluator:run_test_binary_evaluation:avg:{avg_function}:response.data",
                response.data)

            self.assertEqual(response.status_code, status.HTTP_201_CREATED)

            evaluator_id = response.data["id"]
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
            while evaluator_object.task.status != Task.STATUS_COMPLETED:
                print_output(
                    f"evaluator:run_test_binary_evaluation:avg:{avg_function}: waiting for evaluation task to finish, current status:",
                    evaluator_object.task.status)
                sleep(1)
                # Reload the task from the database so the loop sees status updates
                evaluator_object.task.refresh_from_db()

            evaluator_json = evaluator_object.to_json()
            evaluator_json.pop("individual_results")

            print_output(
                f"evaluator:run_test_binary_evaluation_avg_{avg_function}:evaluator_object.json:",
                evaluator_json)

            for metric in choices.METRICS:
                self.assertEqual(round(evaluator_json[metric], 2),
                                 expected_scores[avg_function][metric])

            self.assertEqual(evaluator_object.n_total_classes, 2)
            self.assertEqual(evaluator_object.n_true_classes, 2)
            self.assertEqual(evaluator_object.n_predicted_classes, 2)

            cm = np.array(json.loads(evaluator_object.confusion_matrix))
            cm_size = np.shape(cm)

            self.assertEqual(evaluator_object.n_total_classes, cm_size[0])
            self.assertEqual(evaluator_object.n_total_classes, cm_size[1])

            self.assertEqual(evaluator_object.document_count, 2000)
            self.assertEqual(evaluator_object.add_individual_results,
                             payload["add_individual_results"])
            self.assertEqual(evaluator_object.scores_imprecise, False)
            self.assertEqual(evaluator_object.evaluation_type, "binary")

            self.assertEqual(evaluator_object.average_function, avg_function)

            self.binary_evaluators[avg_function] = evaluator_id

            self.add_cleanup_files(evaluator_id)
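
    # --- Illustrative sketch (not part of the original test suite) ---
    # The expected_scores above differ per averaging function even though the
    # underlying predictions are identical, while accuracy stays the same
    # because it does not depend on averaging. A minimal sketch of how the four
    # modes diverge on toy binary labels, assuming scikit-learn is available
    # (the evaluator's own scoring code is not shown here):
    @staticmethod
    def _sketch_binary_average_functions() -> dict:
        from sklearn.metrics import accuracy_score, precision_recall_fscore_support

        y_true = [1, 0, 1, 1, 0, 0]
        y_pred = [1, 1, 0, 1, 0, 0]

        scores = {"accuracy": round(accuracy_score(y_true, y_pred), 2)}
        for avg in ("binary", "micro", "macro", "weighted"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, average=avg, zero_division=0)
            scores[avg] = {
                "precision": round(precision, 2),
                "recall": round(recall, 2),
                "f1_score": round(f1, 2)
            }
        return scores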

    def run_test_evaluation_with_query(self):
        """ Test if running the evaluation with query works. """
        payload = {
            "description": "Test evaluation with query",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 500,
            "add_individual_results": False,
            "average_function": "macro",
            "query": json.dumps(self.test_query)
        }

        response = self.client.post(self.url, payload, format="json")
        print_output(
            f"evaluator:run_test_evaluation_with_query:avg:{payload['average_function']}:response.data",
            response.data)

        evaluator_id = response.data["id"]
        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

        while evaluator_object.task.status != Task.STATUS_COMPLETED:
            print_output(
                f"evaluator:run_test_evaluation_with_query:avg:{payload['average_function']}: waiting for evaluation task to finish, current status:",
                evaluator_object.task.status)
            sleep(1)
            # Reload the task from the database so the loop sees status updates
            evaluator_object.task.refresh_from_db()

        # Check if the document count is in sync with the query
        self.assertEqual(evaluator_object.document_count, 83)
        self.add_cleanup_files(evaluator_id)
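
    # --- Illustrative sketch (not part of the original test suite) ---
    # Several tests above busy-wait on the evaluation task. A small helper with
    # a timeout would keep a failed task from hanging the test run; a sketch
    # assuming the task row can simply be re-read with refresh_from_db():
    @staticmethod
    def _sketch_wait_for_task(task, timeout_seconds: int = 120, poll_interval: int = 1) -> str:
        waited = 0
        while task.status != Task.STATUS_COMPLETED and waited < timeout_seconds:
            sleep(poll_interval)
            waited += poll_interval
            # Reload the status written by the worker process
            task.refresh_from_db()
        return task.status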

    def run_test_multilabel_evaluation(self, add_individual_results: bool):
        """ Test multilabvel evaluation with averaging functions set in self.multilabel_avg_functions"""

        main_payload = {
            "description": "Test multilabel evaluation",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 500,
            "add_individual_results": add_individual_results
        }

        expected_scores = {
            "weighted": {
                "accuracy": 0,
                "precision": 0.57,
                "recall": 0.67,
                "f1_score": 0.62
            },
            "micro": {
                "accuracy": 0,
                "precision": 0.57,
                "recall": 0.67,
                "f1_score": 0.62
            },
            "macro": {
                "accuracy": 0,
                "precision": 0.57,
                "recall": 0.67,
                "f1_score": 0.62
            },
            "samples": {
                "accuracy": 0,
                "precision": 0.55,
                "recall": 0.73,
                "f1_score": 0.61
            }
        }

        for avg_function in self.multilabel_avg_functions:
            avg_function_payload = {"average_function": avg_function}
            payload = {**main_payload, **avg_function_payload}

            response = self.client.post(self.url, payload, format="json")
            print_output(
                f"evaluator:run_test_multilabel_evaluation:avg:{avg_function}:response.data",
                response.data)

            self.assertEqual(response.status_code, status.HTTP_201_CREATED)

            evaluator_id = response.data["id"]
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
            while evaluator_object.task.status != Task.STATUS_COMPLETED:
                print_output(
                    f"evaluator:run_test_multilabel_evaluation:avg:{avg_function}: waiting for evaluation task to finish, current status:",
                    evaluator_object.task.status)
                sleep(1)
                # Reload the task from the database so the loop sees status updates
                evaluator_object.task.refresh_from_db()

            evaluator_json = evaluator_object.to_json()
            evaluator_json.pop("individual_results")

            print_output(
                f"evaluator:run_test_multilabel_evaluation:avg:{avg_function}:evaluator_object.json:",
                evaluator_json)
            for metric in choices.METRICS:
                self.assertEqual(round(evaluator_json[metric], 2),
                                 expected_scores[avg_function][metric])

            self.assertEqual(evaluator_object.n_total_classes, 10)
            self.assertEqual(evaluator_object.n_true_classes, 10)
            self.assertEqual(evaluator_object.n_predicted_classes, 10)

            cm = np.array(json.loads(evaluator_object.confusion_matrix))
            cm_size = np.shape(cm)

            self.assertEqual(evaluator_object.n_total_classes, cm_size[0])
            self.assertEqual(evaluator_object.n_total_classes, cm_size[1])

            self.assertEqual(evaluator_object.document_count, 2000)
            self.assertEqual(evaluator_object.add_individual_results,
                             add_individual_results)
            self.assertEqual(evaluator_object.scores_imprecise, False)
            self.assertEqual(evaluator_object.evaluation_type, "multilabel")
            self.assertEqual(evaluator_object.average_function, avg_function)

            if add_individual_results:
                self.assertEqual(
                    len(json.loads(evaluator_object.individual_results)),
                    evaluator_object.n_total_classes)
                self.multilabel_evaluators[avg_function] = evaluator_id
            else:
                self.assertEqual(
                    len(json.loads(evaluator_object.individual_results)), 0)

            self.add_cleanup_files(evaluator_id)
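
    # --- Illustrative sketch (not part of the original test suite) ---
    # In the multilabel expected_scores above, accuracy is 0 for every averaging
    # function: multilabel accuracy is subset accuracy, so one wrong label per
    # document already drives it to zero. A minimal sketch with toy label sets,
    # assuming scikit-learn is available:
    @staticmethod
    def _sketch_multilabel_average_functions() -> dict:
        from sklearn.metrics import accuracy_score, precision_recall_fscore_support
        from sklearn.preprocessing import MultiLabelBinarizer

        y_true = [{"a", "b"}, {"b"}, {"a", "c"}]
        y_pred = [{"a"}, {"b", "c"}, {"a", "b", "c"}]

        binarizer = MultiLabelBinarizer()
        y_true_bin = binarizer.fit_transform(y_true)
        y_pred_bin = binarizer.transform(y_pred)

        # No document has its label set predicted exactly, so subset accuracy is 0
        scores = {"accuracy": accuracy_score(y_true_bin, y_pred_bin)}
        for avg in ("micro", "macro", "weighted", "samples"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true_bin, y_pred_bin, average=avg, zero_division=0)
            scores[avg] = {"precision": precision, "recall": recall, "f1_score": f1}
        return scores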

    def run_test_multilabel_evaluation_with_scoring_after_each_scroll(
            self, add_individual_results: bool):
        """
        Test multilabel evaluation with averaging functions set in self.multilabel_avg_functions,
        calculating and averaging the scores after each scroll.
        """

        # Set the required memory buffer high to force scoring after each scroll
        set_core_setting("TEXTA_EVALUATOR_MEMORY_BUFFER_GB", "100")

        main_payload = {
            "description": "Test Multilabel Evaluator",
            "indices": [{
                "name": self.test_index
            }],
            "true_fact": self.true_fact_name,
            "predicted_fact": self.pred_fact_name,
            "scroll_size": 500,
            "add_individual_results": add_individual_results,
        }

        for avg_function in self.multilabel_avg_functions:
            avg_function_payload = {"average_function": avg_function}
            payload = {**main_payload, **avg_function_payload}

            response = self.client.post(self.url, payload, format="json")
            print_output(
                f"evaluator:run_test_multilabel_evaluation_with_scoring_after_each_scroll:avg:{avg_function}:response.data",
                response.data)

            self.assertEqual(response.status_code, status.HTTP_201_CREATED)

            evaluator_id = response.data["id"]
            evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
            while evaluator_object.task.status != Task.STATUS_COMPLETED:
                print_output(
                    f"evaluator:run_test_multilabel_evaluation_with_scoring_after_each_scroll:avg:{avg_function}: waiting for evaluation task to finish, current status:",
                    evaluator_object.task.status)
                sleep(1)
                # Reload the task from the database so the loop sees status updates
                evaluator_object.task.refresh_from_db()

            evaluator_json = evaluator_object.to_json()
            evaluator_json.pop("individual_results")

            print_output(
                f"evaluator:run_test_multilabel_evaluation_with_scoring_after_each_scroll:avg:{avg_function}:evaluator_object.json:",
                evaluator_json)
            for metric in choices.METRICS:
                if metric == "accuracy":
                    self.assertEqual(evaluator_json[metric], 0)
                else:
                    self.assertTrue(0.5 <= evaluator_json[metric] <= 0.8)

            self.assertEqual(evaluator_object.n_total_classes, 10)
            self.assertEqual(evaluator_object.n_true_classes, 10)
            self.assertEqual(evaluator_object.n_predicted_classes, 10)

            cm = np.array(json.loads(evaluator_object.confusion_matrix))
            cm_size = np.shape(cm)

            self.assertEqual(evaluator_object.n_total_classes, cm_size[0])
            self.assertEqual(evaluator_object.n_total_classes, cm_size[1])

            # Only micro-averaged scores stay exact when scoring after each scroll
            scores_imprecise = avg_function != "micro"

            self.assertEqual(evaluator_object.document_count, 2000)
            self.assertEqual(evaluator_object.add_individual_results,
                             add_individual_results)
            self.assertEqual(evaluator_object.scores_imprecise,
                             scores_imprecise)
            self.assertEqual(evaluator_object.evaluation_type, "multilabel")
            self.assertEqual(evaluator_object.average_function, avg_function)

            if add_individual_results:
                self.assertEqual(
                    len(json.loads(evaluator_object.individual_results)),
                    evaluator_object.n_total_classes)
                self.memory_optimized_multilabel_evaluators[
                    avg_function] = evaluator_id
            else:
                self.assertEqual(
                    len(json.loads(evaluator_object.individual_results)), 0)

            self.add_cleanup_files(evaluator_id)

        # Set memory buffer back to default
        set_core_setting("TEXTA_EVALUATOR_MEMORY_BUFFER_GB", "")

    def run_export_import(self, evaluator_id: int):
        """Tests endpoint for model export and import."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)

        eval_type = evaluator_object.evaluation_type
        avg_function = evaluator_object.average_function

        # Retrieve the model zip
        url = f"{self.url}{evaluator_id}/export_model/"
        response = self.client.get(url)

        # Post model zip
        import_url = f"{self.url}import_model/"
        response = self.client.post(import_url,
                                    data={"file": BytesIO(response.content)})

        print_output(
            f"evaluator:run_export_import:avg:{avg_function}:evaluation_type:{eval_type}:response.data",
            response.data)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        imported_evaluator_object = EvaluatorObject.objects.get(
            pk=response.data["id"])
        imported_evaluator_id = response.data["id"]

        # Check that the model and plot files exist.
        resources = imported_evaluator_object.get_resource_paths()
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertTrue(file.exists())

        evaluator_object_json = evaluator_object.to_json()
        imported_evaluator_object_json = imported_evaluator_object.to_json()

        # Check if scores in original and imported model are the same
        for metric in choices.METRICS:
            self.assertEqual(evaluator_object_json[metric],
                             imported_evaluator_object_json[metric])

        # Check that the number of individual result labels is the same
        self.assertEqual(
            len(json.loads(evaluator_object.individual_results)),
            len(json.loads(imported_evaluator_object.individual_results)))

        self.add_cleanup_files(evaluator_id)
        self.add_cleanup_files(imported_evaluator_id)
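
    # --- Illustrative sketch (not part of the original test suite) ---
    # export_model returns the model as a zip archive; besides re-importing it,
    # the archive contents could also be inspected directly with the standard
    # library (the archive layout itself is an assumption):
    @staticmethod
    def _sketch_list_exported_files(export_response_content: bytes) -> list:
        import zipfile
        from io import BytesIO

        with zipfile.ZipFile(BytesIO(export_response_content)) as archive:
            return archive.namelist()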

    def run_reevaluate(self, evaluator_id: int):
        """Tests endpoint for re-evaluation."""
        url = f"{self.url}{evaluator_id}/reevaluate/"
        payload = {}
        response = self.client.post(url, payload, format="json")
        print_output(f"evaluator:run_reevaluate:response.data", response.data)
        self.assertEqual(response.status_code, status.HTTP_200_OK)

    def run_delete(self, evaluator_id: int):
        """Test deleting evaluator and its resources."""

        evaluator_object = EvaluatorObject.objects.get(pk=evaluator_id)
        resources = evaluator_object.get_resource_paths()

        url = f"{self.url}{evaluator_id}/"
        response = self.client.delete(url, format="json")
        print_output(f"evaluator:run_delete:delete:response.data",
                     response.data)

        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

        response = self.client.get(url, format="json")
        print_output(f"evaluator:run_delete:get:response.data", response.data)

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

        # Check that the additional files were deleted as well
        for path in resources.values():
            file = pathlib.Path(path)
            self.assertFalse(file.exists())