コード例 #1
0
    def metrics(self):
        """Assemble the metric objects for the enabled word-level tag outputs.

        The config flags ``predict_target`` (``const.TARGET_TAGS``),
        ``predict_source`` (``const.SOURCE_TAGS``) and ``predict_gaps``
        (``const.GAP_TAGS``) are checked in that order. Each enabled output
        contributes an ``F1Metric`` followed by a ``CorrectMetric``, both keyed
        on that tag name and using ``const.PAD_TAGS_ID`` as ``PAD``. A
        ``LogMetric`` over ``(const.LOSS, const.LOSS)`` is always appended last.

        Returns:
            list: the metric instances, in the order described above.
        """

        def tag_metrics(tag):
            # F1 is built before the Correct metric, matching the final order.
            return [
                F1Metric(
                    prefix=tag,
                    target_name=tag,
                    PAD=const.PAD_TAGS_ID,
                    labels=const.LABELS,
                ),
                CorrectMetric(
                    prefix=tag,
                    target_name=tag,
                    PAD=const.PAD_TAGS_ID,
                ),
            ]

        collected = []
        if self.config.predict_target:
            collected.extend(tag_metrics(const.TARGET_TAGS))
        if self.config.predict_source:
            collected.extend(tag_metrics(const.SOURCE_TAGS))
        if self.config.predict_gaps:
            collected.extend(tag_metrics(const.GAP_TAGS))

        # The loss logger is unconditional and always the final entry.
        collected.append(LogMetric(targets=[(const.LOSS, const.LOSS)]))
        return collected
コード例 #2
0
    def metrics(self):
        """Assemble the metric objects for every output enabled in the config.

        Groups are appended in this order, each only when its condition holds:

        * ``config.predict_target``: ``F1Metric``, ``ThresholdCalibrationMetric``
          and ``CorrectMetric`` on ``const.TARGET_TAGS``;
        * ``config.predict_source``: ``F1Metric`` and ``CorrectMetric`` on
          ``const.SOURCE_TAGS``;
        * ``config.predict_gaps``: ``F1Metric`` and ``CorrectMetric`` on
          ``const.GAP_TAGS``;
        * ``config.sentence_level``: ``RMSEMetric``, ``PearsonMetric`` and
          ``SpearmanMetric`` on ``const.SENTENCE_SCORES``, followed by a
          ``LogMetric`` of ``('model_out', const.SENT_SIGMA)`` when
          ``config.sentence_ll`` is also set;
        * ``config.binary_level``: ``CorrectMetric`` on ``const.BINARY``;
        * ``config.token_level`` and ``self.predictor_tgt`` is not None:
          ``CorrectMetric``, ``ExpectedErrorMetric`` and ``PerplexityMetric``
          on ``const.PE``;
        * ``config.token_level`` and ``self.predictor_src`` is not None: the
          same three metrics on ``const.SOURCE``.

        A ``TokenMetric`` on ``const.TARGET`` is always appended last.

        Returns:
            list: the metric instances in the order listed above.
        """
        collected = []

        def add_tag_metrics(tag, with_calibration=False):
            # Tag-level metrics all pad with const.PAD_TAGS_ID.
            collected.append(
                F1Metric(
                    prefix=tag,
                    target_name=tag,
                    PAD=const.PAD_TAGS_ID,
                    labels=const.LABELS,
                ))
            if with_calibration:
                collected.append(
                    ThresholdCalibrationMetric(
                        prefix=tag,
                        target_name=tag,
                        PAD=const.PAD_TAGS_ID,
                    ))
            collected.append(
                CorrectMetric(
                    prefix=tag,
                    target_name=tag,
                    PAD=const.PAD_TAGS_ID,
                ))

        def add_token_metrics(side):
            # Token-level metrics use const.PAD_ID / const.STOP_ID instead of
            # the tag padding id.
            for metric_cls in (CorrectMetric, ExpectedErrorMetric,
                               PerplexityMetric):
                collected.append(
                    metric_cls(
                        prefix=side,
                        target_name=side,
                        PAD=const.PAD_ID,
                        STOP=const.STOP_ID,
                    ))

        if self.config.predict_target:
            # Only the target tags get a threshold-calibration metric.
            add_tag_metrics(const.TARGET_TAGS, with_calibration=True)
        if self.config.predict_source:
            add_tag_metrics(const.SOURCE_TAGS)
        if self.config.predict_gaps:
            add_tag_metrics(const.GAP_TAGS)

        if self.config.sentence_level:
            for metric_cls in (RMSEMetric, PearsonMetric, SpearmanMetric):
                collected.append(metric_cls(target_name=const.SENTENCE_SCORES))
            if self.config.sentence_ll:
                collected.append(
                    LogMetric(targets=[('model_out', const.SENT_SIGMA)]))

        if self.config.binary_level:
            collected.append(
                CorrectMetric(prefix=const.BINARY, target_name=const.BINARY))

        # The `and` short-circuits, so predictor_* is only read when
        # token_level is enabled.
        if self.config.token_level and self.predictor_tgt is not None:
            add_token_metrics(const.PE)
        if self.config.token_level and self.predictor_src is not None:
            add_token_metrics(const.SOURCE)

        collected.append(
            TokenMetric(target_name=const.TARGET,
                        STOP=const.STOP_ID,
                        PAD=const.PAD_ID))
        return collected
コード例 #3
0
 def metrics(self):
     """Return the metrics tracked for this model.

     Only the loss is tracked: a single ``LogMetric`` built over the
     ``(const.LOSS, const.LOSS)`` pair.

     Returns:
         tuple: a one-element tuple containing that ``LogMetric``.
     """
     # NOTE(review): another metrics() example in this collection passes
     # ``targets=`` to LogMetric; confirm that the LogMetric used here really
     # accepts ``log_targets=``.
     loss_pair = (const.LOSS, const.LOSS)
     loss_metric = LogMetric(log_targets=[loss_pair])
     return (loss_metric,)