Example #1
    def post(self, request, project_pk: int):
        """
        Returns existing fact names and values from Elasticsearch.
        """
        serializer = ProjectGetFactsSerializer(data=request.data)

        if not serializer.is_valid():
            raise SerializerNotValid(detail=serializer.errors)

        indices = serializer.validated_data["indices"]
        indices = [index["name"] for index in indices]

        # retrieve and validate project indices
        project = get_object_or_404(Project, pk=project_pk)
        self.check_object_permissions(request, project)
        project_indices = project.get_available_or_all_project_indices(indices)  # Gives all if none (the default) is entered.

        if not project_indices:
            return Response([])

        vals_per_name = serializer.validated_data['values_per_name']
        include_values = serializer.validated_data['include_values']
        fact_name = serializer.validated_data['fact_name']
        include_doc_path = serializer.validated_data['include_doc_path']
        exclude_zero_spans = serializer.validated_data['exclude_zero_spans']
        mlp_doc_path = serializer.validated_data['mlp_doc_path']

        aggregator = ElasticAggregator(indices=project_indices)

        if mlp_doc_path and exclude_zero_spans:
            # If exclude_zero_spans is enabled and mlp_doc_path is specified, the other values don't have any effect -
            # this behaviour might need to change at some point.
            fact_map = aggregator.facts(size=1, include_values=True, include_doc_path=True, exclude_zero_spans=exclude_zero_spans)

        else:
            fact_map = aggregator.facts(size=vals_per_name, include_values=include_values, filter_by_fact_name=fact_name, include_doc_path=include_doc_path, exclude_zero_spans=exclude_zero_spans)

        if fact_name:
            fact_map_list = [v for v in fact_map]

        elif mlp_doc_path and exclude_zero_spans:
            # Return only fact names where doc_path contains mlp_doc_path as a parent field and facts have spans.
            # NB! Doesn't take into account the situation where facts have the same name, but different doc paths! Could happen!
            fact_map_list = [k for k, v in fact_map.items() if v and mlp_doc_path == v[0]["doc_path"].rsplit(".", 1)[0]]

        elif include_values:
            fact_map_list = [{'name': k, 'values': v} for k, v in fact_map.items()]
        else:
            fact_map_list = [v for v in fact_map]
        return Response(fact_map_list, status=status.HTTP_200_OK)
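
A minimal sketch (not from the source) of how the two fact_map shapes above turn into response payloads, assuming ElasticAggregator.facts() returns a {fact_name: [values]} dict when include_values=True and a plain iterable of fact names otherwise:

# Hypothetical aggregation result, for illustration only.
fact_map = {"ORG": ["TEXTA", "ACME"], "PER": ["Alice", "Bob"]}

# include_values branch: list of {"name": ..., "values": [...]} dicts.
with_values = [{"name": k, "values": v} for k, v in fact_map.items()]
# -> [{"name": "ORG", "values": ["TEXTA", "ACME"]}, {"name": "PER", "values": ["Alice", "Bob"]}]

# Name-only branches: iterating the result yields just the fact names.
names_only = [v for v in fact_map]
# -> ["ORG", "PER"]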
Example #2
def validate_pos_label(data):
    """ For Tagger, TorchTagger and BertTagger.
    Checks if the inserted pos label is present in the fact values.
    """

    fact_name = data.get("fact_name")

    # If fact name is not selected, the value for pos label doesn't matter
    if not fact_name:
        return data

    indices = [index.get("name") for index in data.get("indices")]
    pos_label = data.get("pos_label")
    serializer_query = data.get("query")

    try:
        # If the query is passed as a JSON string
        query = json.loads(serializer_query)
    except Exception:
        # If the query is already passed as a dict
        query = serializer_query

    ag = ElasticAggregator(indices=indices, query=query)
    fact_values = ag.facts(size=10, filter_by_fact_name=fact_name, include_values=True)

    # If there are exactly two possible values for the selected fact, check that a pos label
    # is selected and that it is present in the corresponding fact values.
    if len(fact_values) == 2:
        if not pos_label:
            raise ValidationError(f"The fact values corresponding to the selected query and fact '{fact_name}' are binary. You must specify param 'pos_label' for evaluation purposes. Allowed values for 'pos_label' are: {fact_values}")
        elif pos_label not in fact_values:
            raise ValidationError(f"The specified pos label '{pos_label}' is NOT one of the fact values for fact '{fact_name}'. Please select an existing fact value. Allowed fact values are: {fact_values}")
    return data
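
The try/except above tolerates the query arriving either as a JSON string or as an already parsed dict. A hedged, hypothetical helper (not part of the source) that captures the same idea more explicitly:

import json


def parse_query(serializer_query):
    """Accept the query either as a JSON string or as an already parsed dict (hypothetical helper)."""
    if isinstance(serializer_query, str):
        return json.loads(serializer_query)
    return serializer_query


assert parse_query('{"query": {"match_all": {}}}') == {"query": {"match_all": {}}}
assert parse_query({"query": {"match_all": {}}}) == {"query": {"match_all": {}}}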
Example #3
 def _get_tags(self, fact_name, min_count=50, max_count=None, query={}):
     """Finds possible tags for training by aggregating active project's indices."""
     active_indices = self.tagger_object.get_indices()
     es_a = ElasticAggregator(indices=active_indices, query=query)
     # limit size to 10000 unique tags
     tag_values = es_a.facts(filter_by_fact_name=fact_name,
                             min_count=min_count,
                             max_count=max_count,
                             size=10000)
     return tag_values
Example #4
 def get_tags(self,
              fact_name,
              active_project,
              min_count=1000,
              max_count=None,
              indices=None):
     """Finds possible tags for training by aggregating active project's indices."""
     active_indices = list(
         active_project.get_indices()) if indices is None else indices
     es_a = ElasticAggregator(indices=active_indices)
     # limit size to 10000 unique tags
     tag_values = es_a.facts(filter_by_fact_name=fact_name,
                             min_count=min_count,
                             max_count=max_count,
                             size=10000)
     return tag_values
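
Examples #3 and #4 differ only in where the index list comes from; the actual filtering is done by facts() with the min_count/max_count bounds. A rough, self-contained sketch (an assumption about what such frequency bounds mean, not the library's implementation):

from collections import Counter

# Hypothetical fact values as they might occur across documents.
observed = ["sports"] * 1200 + ["economy"] * 950 + ["culture"] * 40

min_count, max_count = 1000, None
counts = Counter(observed)
trainable_tags = [tag for tag, n in counts.items()
                  if n >= min_count and (max_count is None or n <= max_count)]
# -> ["sports"]; "economy" and "culture" fall below the min_count threshold.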
Example #5
def validate_fact_value(indices: List[str], query: dict, fact: str,
                        fact_value: str):
    """ Check if given fact value exists under given fact. """
    # Fact value is allowed to be empty
    if not fact_value:
        return True

    ag = ElasticAggregator(indices=indices, query=deepcopy(query))

    fact_values = ag.facts(size=choices.DEFAULT_MAX_AGGREGATION_SIZE,
                           filter_by_fact_name=fact,
                           include_values=True)
    if fact_value not in fact_values:
        raise ValidationError(
            f"Fact value '{fact_value}' not in the list of fact values for fact '{fact}'."
        )
    return True
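
A hedged sketch (hypothetical, not from the source) of where a validator like the one above is typically wired in; the serializer, field names and index/query values are placeholders that mirror the Evaluator fields used in Example #6:

from rest_framework import serializers


class HypotheticalEvaluatorSerializer(serializers.Serializer):
    true_fact = serializers.CharField()
    true_fact_value = serializers.CharField(allow_blank=True, required=False, default="")

    def validate(self, data):
        # Placeholder index list and query; in practice these come from the project and request.
        validate_fact_value(indices=["my_index"],
                            query={"query": {"match_all": {}}},
                            fact=data["true_fact"],
                            fact_value=data["true_fact_value"])
        return data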
Example #6
def evaluate_tags_task(object_id: int,
                       indices: List[str],
                       query: dict,
                       es_timeout: int = 10,
                       scroll_size: int = 100):
    try:
        logging.getLogger(INFO_LOGGER).info(
            f"Starting evaluator task for Evaluator with ID {object_id}.")

        evaluator_object = Evaluator.objects.get(pk=object_id)
        progress = ShowProgress(evaluator_object.task, multiplier=1)

        # Retrieve facts and the sklearn average function from the model
        true_fact = evaluator_object.true_fact
        pred_fact = evaluator_object.predicted_fact
        true_fact_value = evaluator_object.true_fact_value
        pred_fact_value = evaluator_object.predicted_fact_value

        average = evaluator_object.average_function
        add_individual_results = evaluator_object.add_individual_results

        searcher = ElasticSearcher(indices=indices,
                                   field_data=["texta_facts"],
                                   query=query,
                                   output=ElasticSearcher.OUT_RAW,
                                   timeout=f"{es_timeout}m",
                                   callback_progress=progress,
                                   scroll_size=scroll_size)

        # Binary
        if true_fact_value and pred_fact_value:
            logging.getLogger(INFO_LOGGER).info(
                f"Starting binary evaluation. Comparing following fact and fact value pairs: TRUE: ({true_fact}: {true_fact_value}), PREDICTED: ({pred_fact}: {pred_fact_value})."
            )

            # Set the evaluation type in the model
            evaluator_object.evaluation_type = "binary"

            true_set = {true_fact_value, "other"}
            pred_set = {pred_fact_value, "other"}

            classes = ["other", true_fact_value]
            n_total_classes = len(classes)

        # Multilabel/multiclass
        else:
            logging.getLogger(INFO_LOGGER).info(
                f"Starting multilabel evaluation. Comparing facts TRUE: '{true_fact}', PRED: '{pred_fact}'."
            )

            # Make deepcopy of the query to avoid modifying Searcher's query.
            es_aggregator = ElasticAggregator(indices=indices,
                                              query=deepcopy(query))

            # Get all fact values corresponding to true and predicted facts to construct total set of labels
            # needed for confusion matrix, individual score calculations and memory imprint calculations
            true_fact_values = es_aggregator.facts(
                size=choices.DEFAULT_MAX_AGGREGATION_SIZE,
                filter_by_fact_name=true_fact)
            pred_fact_values = es_aggregator.facts(
                size=choices.DEFAULT_MAX_AGGREGATION_SIZE,
                filter_by_fact_name=pred_fact)

            true_set = set(true_fact_values)
            pred_set = set(pred_fact_values)

            classes = list(true_set.union(pred_set))
            n_total_classes = len(classes)

            # Add dummy classes for missing labels
            classes.extend(
                [choices.MISSING_TRUE_LABEL, choices.MISSING_PRED_LABEL])

            # Set the evaluation type in the model
            evaluator_object.evaluation_type = "multilabel"

            classes.sort(key=lambda x: x[0].lower())

        # Get number of documents in the query to estimate memory imprint
        n_docs = searcher.count()
        evaluator_object.task.total = n_docs
        evaluator_object.task.save()

        logging.getLogger(INFO_LOGGER).info(
            f"Number of documents: {n_docs} | Number of classes: {len(classes)}"
        )

        # Get the memory buffer value from core variables
        core_memory_buffer_value_gb = get_core_setting(
            "TEXTA_EVALUATOR_MEMORY_BUFFER_GB")

        # Calculate the value based on given ratio if the core variable is empty
        memory_buffer_gb = calculate_memory_buffer(
            memory_buffer=core_memory_buffer_value_gb,
            ratio=EVALUATOR_MEMORY_BUFFER_RATIO,
            unit="gb")

        required_memory = get_memory_imprint(
            n_docs=n_docs,
            n_classes=len(classes),
            eval_type=evaluator_object.evaluation_type,
            unit="gb",
            int_size=64)
        enough_memory = is_enough_memory_available(
            required_memory=required_memory,
            memory_buffer=memory_buffer_gb,
            unit="gb")

        # Enable scoring after each scroll if there isn't enough memory
        # for calculating the scores for the whole set of documents at once.
        score_after_scroll = not enough_memory

        # If scoring after each scroll is enabled and scores are averaged after each scroll,
        # the results for every averaging function besides `micro` are imprecise.
        scores_imprecise = score_after_scroll and average != "micro"

        # Store the document count, class counts and an indicator of whether the scores are imprecise
        evaluator_object.document_count = n_docs
        evaluator_object.n_true_classes = len(true_set)
        evaluator_object.n_predicted_classes = len(pred_set)
        evaluator_object.n_total_classes = n_total_classes
        evaluator_object.scores_imprecise = scores_imprecise
        evaluator_object.score_after_scroll = score_after_scroll

        # Save model updates
        evaluator_object.save()

        logging.getLogger(INFO_LOGGER).info(
            f"Enough available memory: {enough_memory} | Score after scroll: {score_after_scroll}"
        )

        # Get number of batches for the logger
        n_batches = math.ceil(n_docs / scroll_size)

        # Scroll and score tags
        scores, bin_scores = scroll_and_score(
            generator=searcher,
            evaluator_object=evaluator_object,
            true_fact=true_fact,
            pred_fact=pred_fact,
            true_fact_value=true_fact_value,
            pred_fact_value=pred_fact_value,
            classes=classes,
            average=average,
            score_after_scroll=score_after_scroll,
            n_batches=n_batches,
            add_individual_results=add_individual_results)

        logging.getLogger(INFO_LOGGER).info(f"Final scores: {scores}")

        for conn in connections.all():
            conn.close_if_unusable_or_obsolete()

        confusion = scores["confusion_matrix"]
        confusion = np.asarray(confusion, dtype="int64")

        if len(classes) <= choices.DEFAULT_MAX_CONFUSION_CLASSES:
            # Delete empty rows and columns corresponding to missing pred/true labels from the confusion matrix
            confusion, classes = delete_empty_rows_and_cols(confusion, classes)

        scores["confusion_matrix"] = confusion.tolist()

        # Generate confusion matrix plot and save it
        image_name = f"{secrets.token_hex(15)}.png"
        evaluator_object.plot.save(image_name,
                                   create_confusion_plot(
                                       scores["confusion_matrix"], classes),
                                   save=False)
        image_path = pathlib.Path(MEDIA_URL) / image_name
        evaluator_object.plot.name = str(image_path)

        # Add final scores to the model
        evaluator_object.precision = scores["precision"]
        evaluator_object.recall = scores["recall"]
        evaluator_object.f1_score = scores["f1_score"]
        evaluator_object.accuracy = scores["accuracy"]
        evaluator_object.confusion_matrix = json.dumps(
            scores["confusion_matrix"])

        evaluator_object.individual_results = json.dumps(
            remove_not_found(bin_scores), ensure_ascii=False)
        evaluator_object.add_misclassified_examples = False

        evaluator_object.save()
        evaluator_object.task.complete()
        return True

    except Exception as e:
        logging.getLogger(ERROR_LOGGER).exception(e)
        error_message = f"{str(e)[:100]}..."  # Take first 100 characters in case the error message is massive.
        evaluator_object.task.add_error(error_message)
        evaluator_object.task.update_status(Task.STATUS_FAILED)
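
A condensed, self-contained sketch (an assumption-laden simplification, not the source implementation) of the class-set logic above: binary evaluation compares a single fact value against everything else ("other"), while multilabel evaluation unions all observed true and predicted fact values (the dummy missing-label classes are omitted here):

def build_classes(true_fact_value=None, pred_fact_value=None,
                  true_values=(), pred_values=()):
    # Binary branch: one positive class plus the catch-all "other".
    if true_fact_value and pred_fact_value:
        return ["other", true_fact_value]
    # Multilabel branch: union of every observed value, sorted case-insensitively
    # by the first character, as in the task above.
    return sorted(set(true_values) | set(pred_values), key=lambda x: x[0].lower())


build_classes(true_fact_value="sports", pred_fact_value="sports")
# -> ["other", "sports"]
build_classes(true_values=["economy", "sports"], pred_values=["sports", "culture"])
# -> ["culture", "economy", "sports"]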