Example #1
    def benchmarkPredictExtractorManualActuation(self):
        """Benchmark PredictExtractorV2 "manually"."""
        self._init_model()
        records = self._readDatasetIntoExtracts()
        extracts = []
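        # Parse each input record into a TFMA extract.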
        for elem in records:
            extracts.append(
                input_extractor._ParseExample(elem, self._eval_config))  # pylint: disable=protected-access

        prediction_do_fn = predict_extractor_v2._PredictionDoFn(  # pylint: disable=protected-access
            eval_config=self._eval_config,
            eval_shared_models={"": self._eval_shared_model})
        prediction_do_fn.setup()

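        # Time only the batched prediction loop; parsing and setup above are excluded.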
        start = time.time()
        predict_result = []
        predict_batch_size = 1000
        for batch in benchmark_utils.batched_iterator(extracts,
                                                      predict_batch_size):
            predict_result.extend(prediction_do_fn.process(batch))

        end = time.time()
        delta = end - start
        self.report_benchmark(iters=1,
                              wall_time=delta,
                              extras={"num_examples": len(records)})
Example #2
    def _runMetricsAndPlotsEvaluatorManualActuation(self,
                                                    with_confidence_intervals,
                                                    metrics_specs=None):
        """Benchmark MetricsAndPlotsEvaluatorV2 "manually"."""
        self._init_model()
        if not metrics_specs:
            metrics_specs = self._eval_config.metrics_specs

        records = self._readDatasetIntoExtracts()
        extracts = []
        for elem in records:
            extracts.append(
                input_extractor._ParseExample(elem, self._eval_config))  # pylint: disable=protected-access

        prediction_do_fn = predict_extractor_v2._PredictionDoFn(  # pylint: disable=protected-access
            eval_config=self._eval_config,
            eval_shared_models={"": self._eval_shared_model})
        prediction_do_fn.setup()

        # Run predictions first; the evaluator operates on prediction results.
        predict_result = []
        predict_batch_size = 1000
        for batch in benchmark_utils.batched_iterator(extracts,
                                                      predict_batch_size):
            predict_result.extend(prediction_do_fn.process(batch))

        # Now evaluate. Only this section is timed.
        inputs_per_accumulator = 1000
        start = time.time()

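        # Build the metric computations from the metrics specs; the second group
        # of computations returned is unused here.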
        computations, _ = (
            metrics_and_plots_evaluator_v2._filter_and_separate_computations(  # pylint: disable=protected-access
                metric_specs.to_computations(metrics_specs,
                                             eval_config=self._eval_config)))

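        # Run the preprocessor on each prediction to produce inputs for the combiner.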
        processed = []
        for elem in predict_result:
            processed.append(
                next(
                    metrics_and_plots_evaluator_v2._PreprocessorDoFn(  # pylint: disable=protected-access
                        computations).process(elem)))

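        # Build the combiner that computes the metrics, with sampling when
        # confidence intervals are requested.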
        combiner = metrics_and_plots_evaluator_v2._ComputationsCombineFn(  # pylint: disable=protected-access
            computations=computations,
            compute_with_sampling=with_confidence_intervals)

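        # Emulate the combine phase: one accumulator per batch of inputs, then
        # merge the accumulators and extract the final metrics.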
        accumulators = []
        for batch in benchmark_utils.batched_iterator(processed,
                                                      inputs_per_accumulator):
            accumulator = combiner.create_accumulator()
            for elem in batch:
                accumulator = combiner.add_input(accumulator, elem)
            accumulators.append(accumulator)

        final_accumulator = combiner.merge_accumulators(accumulators)
        final_output = combiner.extract_output(final_accumulator)
        end = time.time()
        delta = end - start

        # Sanity check the example count. This is not timed.
        example_count_key = metric_types.MetricKey(name="example_count")
        example_count = None
        for x in final_output:
            if example_count_key in x:
                example_count = x[example_count_key]
                break

        if example_count is None:
            raise ValueError(
                "example_count was not in the final list of metrics. "
                "metrics were: %s" % str(final_output))

        if with_confidence_intervals:
            # If we're computing using confidence intervals, the example count will
            # not be exact.
            lower_bound = int(0.9 * len(records))
            upper_bound = int(1.1 * len(records))
            if example_count < lower_bound or example_count > upper_bound:
                raise ValueError("example count out of bounds: expecting "
                                 "%d < example_count < %d, but got %d" %
                                 (lower_bound, upper_bound, example_count))
        else:
            # If we're not using confidence intervals, we expect the example count to
            # be exact.
            if example_count != len(records):
                raise ValueError(
                    "example count mismatch: expecting %d got %d" %
                    (len(records), example_count))

        self.report_benchmark(
            iters=1,
            wall_time=delta,
            extras={
                "inputs_per_accumulator": inputs_per_accumulator,
                "num_examples": len(records)
            })