Example #1
    def _runMetricsPlotsAndValidationsEvaluatorManualActuation(
            self,
            with_confidence_intervals,
            multi_model,
            metrics_specs=None,
            validation=False):
        """Benchmark MetricsPlotsAndValidationsEvaluator "manually"."""
        self._init_model(multi_model, validation)
        if not metrics_specs:
            metrics_specs = self._eval_config.metrics_specs

        extracts = self._readDatasetIntoBatchedExtracts()
        num_examples = sum(
            [e[constants.ARROW_RECORD_BATCH_KEY].num_rows for e in extracts])
        extracts = [self._extract_features_and_labels(e) for e in extracts]

        prediction_do_fn = model_util.ModelSignaturesDoFn(
            eval_config=self._eval_config,
            eval_shared_models=self._eval_shared_models,
            signature_names={
                constants.PREDICTIONS_KEY:
                {name: [None]
                 for name in self._eval_shared_models}
            },
            prefer_dict_outputs=False)
        prediction_do_fn.setup()

        # Have to run the prediction step first so the extracts contain predictions.
        predict_result = []
        for e in extracts:
            predict_result.extend(prediction_do_fn.process(e))

        # Unbatch extracts
        unbatched_extracts = []
        for e in predict_result:
            unbatched_extracts.extend(
                unbatch_extractor._extract_unbatched_inputs(e))  # pylint: disable=protected-access

        # Add global slice key.
        for e in unbatched_extracts:
            e[tfma.SLICE_KEY_TYPES_KEY] = ()

        # Now evaluate the metrics.
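        # Feed the combiner in fixed-size chunks: each chunk of extracts yields
        # one accumulator, and the accumulators are merged at the end, roughly
        # mirroring how a Beam combine would partition its inputs.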
        inputs_per_accumulator = 1000
        start = time.time()
        for _ in range(_ITERS):
            computations, _, _, _ = (
                # pylint: disable=protected-access
                metrics_plots_and_validations_evaluator.
                _filter_and_separate_computations(
                    metric_specs_util.to_computations(
                        metrics_specs, eval_config=self._eval_config)))
            # pylint: enable=protected-access

            processed = []
            for elem in unbatched_extracts:
                processed.append(
                    next(
                        metrics_plots_and_validations_evaluator.
                        _PreprocessorDoFn(  # pylint: disable=protected-access
                            computations).process(elem)))

            combiner = metrics_plots_and_validations_evaluator._ComputationsCombineFn(  # pylint: disable=protected-access
                computations=computations)
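            # With confidence intervals enabled, wrap the combiner in the
            # Poisson-bootstrap combiner, which resamples its inputs; this is
            # why the example-count check below only enforces loose bounds.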
            if with_confidence_intervals:
                combiner = poisson_bootstrap._BootstrapCombineFn(combiner)  # pylint: disable=protected-access
            combiner.setup()

            accumulators = []
            for batch in benchmark_utils.batched_iterator(
                    processed, inputs_per_accumulator):
                accumulator = combiner.create_accumulator()
                for elem in batch:
                    accumulator = combiner.add_input(accumulator, elem)
                accumulators.append(accumulator)

            final_accumulator = combiner.merge_accumulators(accumulators)
            final_output = combiner.extract_output(final_accumulator)
        end = time.time()
        delta = end - start

        # Sanity check the example count. This is not timed.
        example_count_key = metric_types.MetricKey(
            name="example_count",
            model_name="candidate" if multi_model else "")
        if example_count_key in final_output:
            example_count = final_output[example_count_key]
        else:
            raise ValueError(
                "example_count_key ({}) was not in the final list of "
                "metrics. metrics were: {}".format(example_count_key,
                                                   final_output))

        if with_confidence_intervals:
            # If we're computing with confidence intervals, the example count
            # will not be exact.
            lower_bound = int(0.9 * num_examples)
            upper_bound = int(1.1 * num_examples)
            if example_count < lower_bound or example_count > upper_bound:
                raise ValueError("example count out of bounds: expecting "
                                 "%d < example_count < %d, but got %d" %
                                 (lower_bound, upper_bound, example_count))
        else:
            # If we're not using confidence intervals, we expect the example count to
            # be exact.
            if example_count != num_examples:
                raise ValueError(
                    "example count mismatch: expecting %d got %d" %
                    (num_examples, example_count))

        self.report_benchmark(iters=_ITERS,
                              wall_time=delta,
                              extras={
                                  "inputs_per_accumulator":
                                  inputs_per_accumulator,
                                  "num_examples": num_examples
                              })
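For contrast with the manual actuation above, the same evaluation is normally driven end to end through TFMA's public, Beam-backed entry point rather than by calling the DoFns and CombineFn by hand. The sketch below is illustrative only: it assumes a reasonably recent TFMA release, and the paths, label key, and metric choice are placeholders rather than values taken from the benchmark.

import tensorflow as tf
import tensorflow_model_analysis as tfma

# Placeholder config: substitute your own SavedModel, label key, and metrics.
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key="label")],
    metrics_specs=tfma.metrics.specs_from_metrics(
        [tf.keras.metrics.BinaryAccuracy(name="accuracy")]),
    slicing_specs=[tfma.SlicingSpec()])  # an empty spec is the overall slice

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path="/path/to/saved_model", eval_config=eval_config)

# Runs extraction, prediction, unbatching, slicing, and metric combination in
# a Beam pipeline, i.e. the steps the benchmark above performs manually.
eval_result = tfma.run_model_analysis(
    eval_shared_model=eval_shared_model,
    eval_config=eval_config,
    data_location="/path/to/examples-*.tfrecord",
    output_path="/tmp/tfma_output")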
Example #2
    def _runMetricsAndPlotsEvaluatorManualActuation(self,
                                                    with_confidence_intervals,
                                                    metrics_specs=None):
        """Benchmark MetricsAndPlotsEvaluatorV2 "manually"."""
        self._init_model()
        if not metrics_specs:
            metrics_specs = self._eval_config.metrics_specs

        extracts = self._readDatasetIntoBatchedExtracts()
        num_examples = sum(
            [e[constants.ARROW_RECORD_BATCH_KEY].num_rows for e in extracts])
        extracts = [
            batched_input_extractor._ExtractInputs(e, self._eval_config)  # pylint: disable=protected-access
            for e in extracts
        ]

        prediction_do_fn = batched_predict_extractor_v2._BatchedPredictionDoFn(  # pylint: disable=protected-access
            eval_config=self._eval_config,
            eval_shared_models={"": self._eval_shared_model})
        prediction_do_fn.setup()

        # Have to run the prediction step first so the extracts contain predictions.
        predict_result = []
        for e in extracts:
            predict_result.extend(prediction_do_fn.process(e))

        # Unbatch extracts
        unbatched_extracts = []
        for e in predict_result:
            unbatched_extracts.extend(
                unbatch_extractor._ExtractUnbatchedInputs(e))  # pylint: disable=protected-access

        # Add global slice key.
        for e in unbatched_extracts:
            e[tfma.SLICE_KEY_TYPES_KEY] = ()

        # Now evaluate the metrics.
        inputs_per_accumulator = 1000
        start = time.time()

        computations, _ = (
            # pylint: disable=protected-access
            metrics_plots_and_validations_evaluator.
            _filter_and_separate_computations(
                metric_specs.to_computations(metrics_specs,
                                             eval_config=self._eval_config)))
        # pylint: enable=protected-access

        processed = []
        for elem in unbatched_extracts:
            processed.append(
                next(
                    metrics_plots_and_validations_evaluator._PreprocessorDoFn(  # pylint: disable=protected-access
                        computations).process(elem)))

        combiner = metrics_plots_and_validations_evaluator._ComputationsCombineFn(  # pylint: disable=protected-access
            computations=computations,
            compute_with_sampling=with_confidence_intervals)
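        # Note: in this older evaluator API, bootstrap resampling for confidence
        # intervals is switched on via compute_with_sampling above rather than by
        # wrapping the combiner, so the example count is again only approximate
        # when with_confidence_intervals is set.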

        accumulators = []
        for batch in benchmark_utils.batched_iterator(processed,
                                                      inputs_per_accumulator):
            accumulator = combiner.create_accumulator()
            for elem in batch:
                accumulator = combiner.add_input(accumulator, elem)
            accumulators.append(accumulator)

        final_accumulator = combiner.merge_accumulators(accumulators)
        final_output = combiner.extract_output(final_accumulator)
        end = time.time()
        delta = end - start

        # Sanity check the example count. This is not timed.
        example_count_key = metric_types.MetricKey(name="example_count")
        example_count = None
        for x in final_output:
            if example_count_key in x:
                example_count = x[example_count_key]
                break

        if example_count is None:
            raise ValueError(
                "example_count was not in the final list of metrics. "
                "metrics were: %s" % str(final_output))

        if with_confidence_intervals:
            # If we're computing with confidence intervals, the example count
            # will not be exact.
            lower_bound = int(0.9 * num_examples)
            upper_bound = int(1.1 * num_examples)
            if example_count < lower_bound or example_count > upper_bound:
                raise ValueError("example count out of bounds: expecting "
                                 "%d < example_count < %d, but got %d" %
                                 (lower_bound, upper_bound, example_count))
        else:
            # If we're not using confidence intervals, we expect the example count to
            # be exact.
            if example_count != num_examples:
                raise ValueError(
                    "example count mismatch: expecting %d got %d" %
                    (num_examples, example_count))

        self.report_benchmark(iters=1,
                              wall_time=delta,
                              extras={
                                  "inputs_per_accumulator":
                                  inputs_per_accumulator,
                                  "num_examples": num_examples
                              })
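Both examples rely on benchmark_utils.batched_iterator to split the processed extracts into chunks of inputs_per_accumulator before feeding the combiner. That helper is not shown on this page; a rough stand-in built on itertools.islice might look like the following (the name and signature are assumptions, not the benchmark's actual implementation):

import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batched_iterator(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yields successive lists of at most batch_size items from items."""
    iterator = iter(items)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            return
        yield batch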