Example #1
    def test(self, ts, steps=0, **kwargs):
        """Method that evaluates on some data.  There are 2 modes this can run in, `feed_dict` and `dataset`

        In `feed_dict` mode, the model cycles the test data batch-wise and feeds each batch in with a `feed_dict`.
        In `dataset` mode, the data is still passed in to this method, but it is not passed in a `feed_dict` and is
        mostly superfluous since the features are grafted right onto the graph.  However, we do use it for supplying
        the ground truth, ids and text, so it is essential that the caller does not shuffle the data
        :param ts: The test set
        :param conll_output: (`str`) An optional file output
        :param txts: A list of text data associated with the encoded batch
        :param dataset: (`bool`) Is this using `tf.dataset`s
        :return: The metrics
        """
        SET_TRAIN_FLAG(False)

        total_correct = total_sum = 0
        gold_spans = []
        pred_spans = []

        self.cm = ConfusionMatrix(self.idx2classlabel)

        handle = None
        if kwargs.get("conll_output") is not None and kwargs.get(
                'txts') is not None:
            handle = open(kwargs.get("conll_output"), "w")

        try:
            pg = create_progress_bar(steps)
            metrics = {}
            for (features, y), batch in pg(
                    zip_longest(ts, kwargs.get('batches', []), fillvalue={})):
                correct, count, golds, guesses = self.process_batch(
                    features,
                    y,
                    handle=handle,
                    txts=kwargs.get("txts"),
                    ids=batch.get("ids"))
                total_correct += correct
                total_sum += count
                gold_spans.extend(golds)
                pred_spans.extend(guesses)

            total_acc = total_correct / float(total_sum)
            # Span-level F1 and token-level accuracy; the conlleval breakdown below is only shown in verbose mode
            metrics['tagging_f1'] = span_f1(gold_spans, pred_spans)
            metrics['tagging_acc'] = total_acc
            metrics.update({
                f"classification_{k}": v
                for k, v in self.cm.get_all_metrics().items()
            })
            if self.verbose:
                conll_metrics = per_entity_f1(gold_spans, pred_spans)
                conll_metrics['acc'] = total_acc * 100
                conll_metrics['tokens'] = total_sum
                logger.info(conlleval_output(conll_metrics))
        finally:
            if handle is not None:
                handle.close()

        return metrics
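
Example #1 pairs each `(features, y)` batch from the `tf.data` pipeline with an optional side-information dict via `itertools.zip_longest`, so evaluation still works when no `batches` kwarg is supplied. Below is a small, self-contained sketch of that pairing pattern; the toy data is made up purely for illustration.

    from itertools import zip_longest

    dataset = [({"word": [1, 2, 3]}, [0, 1, 0]),
               ({"word": [4, 5]}, [1, 0])]      # (features, y) pairs, as in the loop above
    side_info = [{"ids": ["sent-0"]}]           # deliberately shorter than the dataset

    for (features, y), batch in zip_longest(dataset, side_info, fillvalue={}):
        # Once side_info runs out, batch is {}, so batch.get("ids") is simply None
        print(features["word"], y, batch.get("ids"))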
Example #2
    def _test(self, ts, **kwargs):

        self.model.eval()
        total_sum = 0
        total_correct = 0

        gold_spans = []
        pred_spans = []
        cm = ConfusionMatrix(self.idx2classlabel)
        metrics = {}
        steps = len(ts)
        conll_output = kwargs.get('conll_output', None)
        txts = kwargs.get('txts', None)
        handle = None
        if conll_output is not None and txts is not None:
            handle = open(conll_output, "w")
        pg = create_progress_bar(steps)
        for batch_dict in pg(ts):

            inputs = self.model.make_input(batch_dict)
            y = inputs.pop('y')
            lengths = inputs['lengths']
            ids = inputs['ids']
            class_labels = inputs["class_label"]
            with torch.no_grad():
                class_pred, pred = self.model(inputs)
            correct, count, golds, guesses = self.process_output(
                pred, y.data, lengths, ids, handle, txts)
            total_correct += correct
            total_sum += count
            gold_spans.extend(golds)
            pred_spans.extend(guesses)
            _add_to_cm(cm, class_labels, class_pred)

        total_acc = total_correct / float(total_sum)
        metrics['tagging_acc'] = total_acc
        metrics['tagging_f1'] = span_f1(gold_spans, pred_spans)
        metrics.update({
            f"classification_{k}": v
            for k, v in cm.get_all_metrics().items()
        })
        if self.verbose:
            # TODO: Add programmatic access to these metrics?
            conll_metrics = per_entity_f1(gold_spans, pred_spans)
            conll_metrics['acc'] = total_acc * 100
            conll_metrics['tokens'] = total_sum.item()
            logger.info(conlleval_output(conll_metrics))
        if handle is not None:
            handle.close()
        return metrics
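
The `_add_to_cm` helper used above is not shown in this example. The sketch below is only a guess at what it does, assuming the `ConfusionMatrix` here exposes an `add(truth, guess)` method: take the argmax over the classification logits and record each gold/predicted label pair.

    import torch

    def _add_to_cm(cm, class_labels, class_pred):
        # Hypothetical sketch, not the library's implementation: turn logits into
        # predicted label indices and add every (gold, guess) pair to the matrix.
        guesses = torch.argmax(class_pred, dim=-1)
        for truth, guess in zip(class_labels.tolist(), guesses.tolist()):
            cm.add(truth, guess)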
Example #3
    def test(self, ts, conll_output=None, txts=None, dataset=True):
        """Method that evaluates on some data.  There are 2 modes this can run in, `feed_dict` and `dataset`

        In `feed_dict` mode, the model cycles the test data batch-wise and feeds each batch in with a `feed_dict`.
        In `dataset` mode, the data is still passed in to this method, but it is not passed in a `feed_dict` and is
        mostly superfluous since the features are grafted right onto the graph.  However, we do use it for supplying
        the ground truth, ids and text, so it is essential that the caller does not shuffle the data.
        :param ts: The test set
        :param conll_output: (`str`) An optional file output
        :param txts: A list of text data associated with the encoded batch
        :param dataset: (`bool`) Is this using `tf.dataset`s
        :return: The metrics
        """
        total_correct = total_sum = 0
        gold_spans = []
        pred_spans = []

        steps = len(ts)
        pg = create_progress_bar(steps)
        metrics = {}
        # A CoNLL output file can only be written if both an output path and the raw txts are provided
        handle = None
        if conll_output is not None and txts is not None:
            handle = open(conll_output, "w")

        try:
            for batch_dict in pg(ts):
                correct, count, golds, guesses = self.process_batch(batch_dict, handle, txts, dataset)
                total_correct += correct
                total_sum += count
                gold_spans.extend(golds)
                pred_spans.extend(guesses)

            total_acc = total_correct / float(total_sum)
            # Span-level F1 and token-level accuracy; the conlleval breakdown below is only shown in verbose mode
            metrics['f1'] = span_f1(gold_spans, pred_spans)
            metrics['acc'] = total_acc
            if self.verbose:
                conll_metrics = per_entity_f1(gold_spans, pred_spans)
                conll_metrics['acc'] = total_acc * 100
                conll_metrics['tokens'] = total_sum
                logger.info(conlleval_output(conll_metrics))
        finally:
            if handle is not None:
                handle.close()

        return metrics
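
When both `conll_output` and `txts` are supplied, `process_batch` writes the predictions out through `handle` so they can be scored externally. The function below is not the library's writer, just a sketch of the conlleval-style convention such a file typically follows: one `token gold pred` line per token and a blank line between sentences.

    def write_conll_sentence(handle, tokens, golds, preds):
        # Illustrative sketch only: conlleval-style output is whitespace-separated,
        # with the gold tag before the predicted tag and sentences separated by blank lines.
        for token, gold, pred in zip(tokens, golds, preds):
            handle.write(f"{token} {gold} {pred}\n")
        handle.write("\n")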
Example #4
    def test(self, ts, conll_output=None, txts=None):
        """Method that evaluates on some data.

        :param ts: The test set
        :param conll_output: (`str`) An optional file output
        :param txts: A list of text data associated with the encoded batch
        :return: The metrics
        """
        total_correct = total_sum = 0
        gold_spans = []
        pred_spans = []

        steps = len(ts)
        pg = create_progress_bar(steps)
        metrics = {}
        # A CoNLL output file can only be written if both an output path and the raw txts are provided
        handle = None
        if conll_output is not None and txts is not None:
            handle = open(conll_output, "w")

        try:
            for batch_dict in pg(ts):
                correct, count, golds, guesses = self.process_batch(
                    batch_dict, handle, txts)
                total_correct += correct
                total_sum += count
                gold_spans.extend(golds)
                pred_spans.extend(guesses)

            total_acc = total_correct / float(total_sum)
            # Span-level F1 and token-level accuracy; the conlleval breakdown below is only shown in verbose mode
            metrics['f1'] = span_f1(gold_spans, pred_spans)
            metrics['acc'] = total_acc
            if self.verbose:
                conll_metrics = per_entity_f1(gold_spans, pred_spans)
                conll_metrics['acc'] = total_acc * 100
                conll_metrics['tokens'] = total_sum
                logger.info(conlleval_output(conll_metrics))
        finally:
            if handle is not None:
                handle.close()

        return metrics
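
`span_f1` scores the tagger at the entity level rather than per token: a predicted span only counts if it matches a gold span exactly. A simplified sketch of that exact-match computation (not the library's implementation) looks like this.

    def span_f1_sketch(gold_spans, pred_spans):
        # Simplified exact-match span F1: precision over predicted spans,
        # recall over gold spans, and their harmonic mean.
        gold, pred = set(gold_spans), set(pred_spans)
        correct = len(gold & pred)
        precision = correct / len(pred) if pred else 0.0
        recall = correct / len(gold) if gold else 0.0
        return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0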
Example #5
    def _test(self, ts, **kwargs):

        self.model.eval()
        total_sum = 0
        total_correct = 0

        gold_spans = []
        pred_spans = []

        metrics = {}
        steps = len(ts)
        conll_output = kwargs.get('conll_output', None)
        txts = kwargs.get('txts', None)
        handle = None
        if conll_output is not None and txts is not None:
            handle = open(conll_output, "w")
        pg = create_progress_bar(steps)
        for batch_dict in pg(ts):

            inputs = self.model.make_input(batch_dict)
            y = inputs.pop('y')
            lengths = inputs['lengths']
            ids = inputs['ids']
            with torch.no_grad():
                pred = self.model(inputs)
            correct, count, golds, guesses = self.process_output(
                pred, y.data, lengths, ids, handle, txts)
            total_correct += correct
            total_sum += count
            gold_spans.extend(golds)
            pred_spans.extend(guesses)

        total_acc = total_correct / float(total_sum)
        metrics['acc'] = total_acc
        metrics['f1'] = span_f1(gold_spans, pred_spans)
        if self.verbose:
            # TODO: Add programmatic access to these metrics?
            conll_metrics = per_entity_f1(gold_spans, pred_spans)
            conll_metrics['acc'] = total_acc * 100
            conll_metrics['tokens'] = total_sum.item()
            logger.info(conlleval_output(conll_metrics))
        if handle is not None:
            handle.close()
        return metrics
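
A hedged usage sketch for Example #5, assuming a trainer instance built around this `_test` method; `trainer`, `test_batches`, and `raw_test_sentences` are placeholder names, not identifiers taken from the source.

    # Hypothetical driver: evaluate a held-out set and also dump the predictions
    # in CoNLL form for external scoring.
    metrics = trainer._test(test_batches,
                            conll_output="test.conll",
                            txts=raw_test_sentences)
    print(metrics["f1"], metrics["acc"])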