def process(self, extracts: types.Extracts) -> Iterable[Any]:
    """Builds the per-computation combiner inputs for a single extracts value.

    Computations whose preprocessor is None, or is exactly a
    StandardMetricInputsPreprocessor, get None as their combiner input. None
    signals that the shared default input (stored under
    _DEFAULT_COMBINER_INPUT_KEY) should be handed to the combiner instead.
    All such preprocessors are merged into one
    StandardMetricInputsPreprocessorList so the extracts are converted only
    once; as a consequence the shared input may carry keys that an individual
    metric did not ask for.

    Any other preprocessor is run directly and the first value it yields is
    used as that computation's combiner input.

    Args:
      extracts: The extracts to preprocess. Must contain
        constants.SLICE_KEY_TYPES_KEY.

    Yields:
      A single dict holding the slice key types, the list of combiner inputs
      (one entry per computation, in order) and, when at least one
      computation uses the standard inputs, the shared default combiner input.
    """
    began_at = datetime.datetime.now()
    self._evaluate_num_instances.inc(1)

    per_computation_inputs = []
    shared_preprocessors = []
    default_registered = False
    for computation in self._computations:
        current = computation.preprocessor

        if current is None:
            # The combiner wants the default StandardMetricInputs (i.e.
            # labels, predictions, and example weights). Register the default
            # preprocessor only once no matter how many computations ask.
            per_computation_inputs.append(None)
            if not default_registered:
                shared_preprocessors.append(
                    metric_types.StandardMetricInputsPreprocessor())
                default_registered = True
            continue

        # Exact type match on purpose: subclasses are treated as custom
        # preprocessors and fall through to the last case.
        if (type(current) ==  # pylint: disable=unidiomatic-typecheck
                metric_types.StandardMetricInputsPreprocessor):
            # A custom filter over the standard inputs. It is merged into the
            # shared preprocessor below, so the combiner still gets None here.
            per_computation_inputs.append(None)
            shared_preprocessors.append(current)
            continue

        # Fully custom preprocessor: only its first yielded value is used.
        per_computation_inputs.append(next(current.process(extracts)))

    output = {}
    output[constants.SLICE_KEY_TYPES_KEY] = (
        extracts[constants.SLICE_KEY_TYPES_KEY])
    output[_COMBINER_INPUTS_KEY] = per_computation_inputs

    if shared_preprocessors:
        merged = metric_types.StandardMetricInputsPreprocessorList(
            shared_preprocessors)
        # Shallow copy so the caller's extracts object is not rebound.
        extracts = copy.copy(extracts)
        # NOTE(review): the return value is discarded. If `process` is a
        # generator function this call has no effect - confirm intent.
        merged.process(extracts)
        include_flags = {
            flag: key in merged.include_filter
            for flag, key in (
                ('include_features', constants.FEATURES_KEY),
                ('include_transformed_features',
                 constants.TRANSFORMED_FEATURES_KEY),
                ('include_attributions', constants.ATTRIBUTIONS_KEY),
            )
        }
        output[_DEFAULT_COMBINER_INPUT_KEY] = (
            metric_util.to_standard_metric_inputs(extracts, **include_flags))

    yield output

    # Runs once the consumer resumes the generator; records whole seconds.
    elapsed = datetime.datetime.now() - began_at
    self._timer.update(int(elapsed.total_seconds()))
def testMergeAccumulators(self):
    """Merging two partial accumulators gives the MSE over all five inputs."""
    computation = tf_metric_wrapper.tf_metric_computations(
        [tf.keras.metrics.MeanSquaredError(name='mse')],
        desired_batch_size=2)[0]

    # (label, prediction, example_weight) for each of the five examples.
    rows = (
        (0.0, 0.0, 1.0),
        (0.0, 0.5, 1.0),
        (1.0, 0.3, 1.0),
        (1.0, 0.9, 1.0),
        (1.0, 0.5, 0.0),
    )
    examples = [
        {
            'labels': [label],
            'predictions': [prediction],
            'example_weights': [weight],
        }
        for label, prediction, weight in rows
    ]

    combiner = computation.combiner
    combiner.setup()
    standard_inputs = [
        metric_util.to_standard_metric_inputs(example) for example in examples
    ]

    def accumulate(batch):
        # Fold a run of inputs into a freshly created accumulator.
        accumulator = combiner.create_accumulator()
        for item in batch:
            accumulator = combiner.add_input(accumulator, item)
        return accumulator

    # The first three inputs go to one accumulator, the last two to another.
    first_half = accumulate(standard_inputs[:3])
    second_half = accumulate(standard_inputs[3:])

    merged = combiner.merge_accumulators([first_half, second_half])
    got_metrics = combiner.extract_output(merged)

    # The expected value matches the weighted mean of squared errors:
    # (0 + 0.25 + 0.49 + 0.01 + 0) / (1 + 1 + 1 + 1 + 0) = 0.1875.
    mse_key = metric_types.MetricKey(name='mse')
    self.assertDictElementsAlmostEqual(got_metrics, {mse_key: 0.1875})
def process(
    self,
    extracts: Union[types.Extracts, List[types.Extracts]]) -> Iterable[Any]:
    """Builds combiner inputs for one extracts value or a list of them.

    A list of extracts is handled element by element; a single extracts value
    is wrapped in a one-element list for processing and unwrapped again in the
    default combiner input.

    Computations with no preprocessor, or with a FeaturePreprocessor, get None
    as their combiner input and share the default input stored under
    _DEFAULT_COMBINER_INPUT_KEY. Values yielded by every FeaturePreprocessor
    are merged into one features dict per extracts element. Any other
    preprocessor is called on `extracts` exactly as it was passed in (list or
    not) and only its first yielded value is used.

    Args:
      extracts: A single extracts value or a list of extracts. Every element
        must contain constants.SLICE_KEY_TYPES_KEY.

    Yields:
      A single dict with the de-duplicated slice key types of all elements,
      the list of combiner inputs (one per computation) and, if needed, the
      default combiner input (a list when `extracts` was a list, otherwise a
      single value).
    """
    started = datetime.datetime.now()
    self._evaluate_num_instances.inc(1)

    # Treat everything as a list internally; remember whether to unwrap.
    given_list = isinstance(extracts, list)
    batch = extracts if given_list else [extracts]

    needs_default_input = False
    merged_features = None
    per_computation_inputs = []
    for computation in self._computations:
        if computation.preprocessor is None:
            per_computation_inputs.append(None)
            needs_default_input = True
        elif isinstance(computation.preprocessor,
                        metric_types.FeaturePreprocessor):
            if merged_features is None:
                # One independent features dict per extracts element.
                merged_features = [{} for _ in batch]
            for target, single in zip(merged_features, batch):
                for update in computation.preprocessor.process(single):
                    target.update(update)
            per_computation_inputs.append(None)
            needs_default_input = True
        else:
            per_computation_inputs.append(
                next(computation.preprocessor.process(extracts)))

    output = {}
    # Union of the slice key types across all elements, first-seen order.
    unique_slice_key_types = dict.fromkeys(
        slice_key_type
        for single in batch
        for slice_key_type in single[constants.SLICE_KEY_TYPES_KEY])
    output[constants.SLICE_KEY_TYPES_KEY] = list(unique_slice_key_types)
    output[_COMBINER_INPUTS_KEY] = per_computation_inputs

    if needs_default_input:
        with_features = merged_features is not None
        standard_inputs = []
        for index, single in enumerate(batch):
            if with_features:
                # Copy before attaching features so the input is not mutated.
                single = copy.copy(single)
                single.update({constants.FEATURES_KEY: merged_features[index]})  # pytype: disable=attribute-error
            standard_inputs.append(
                metric_util.to_standard_metric_inputs(
                    single, include_features=with_features))
        if not given_list:
            # A single extracts value came in, so hand back a single value.
            standard_inputs = standard_inputs[0]
        output[_DEFAULT_COMBINER_INPUT_KEY] = standard_inputs

    yield output

    # Runs once the consumer resumes the generator; records whole seconds.
    elapsed = datetime.datetime.now() - started
    self._timer.update(int(elapsed.total_seconds()))
def process(self, extracts: types.Extracts) -> Iterable[Any]:
    """Builds the per-computation combiner inputs for a single extracts value.

    Computations with no preprocessor, or with a FeaturePreprocessor, get None
    as their combiner input and share the default input stored under
    _DEFAULT_COMBINER_INPUT_KEY. Values yielded by every FeaturePreprocessor
    are merged into a single features dict that is attached to a copy of the
    extracts before conversion. Any other preprocessor is run directly and
    only its first yielded value is used.

    Args:
      extracts: The extracts to preprocess. Must contain
        constants.SLICE_KEY_TYPES_KEY.

    Yields:
      A single dict with the slice key types, the list of combiner inputs
      (one per computation, in order) and, if any computation needs it, the
      default combiner input.
    """
    start_time = datetime.datetime.now()
    self._evaluate_num_instances.inc(1)

    use_default_combiner_input = False
    features = None
    combiner_inputs = []
    for computation in self._computations:
        if computation.preprocessor is None:
            combiner_inputs.append(None)
            use_default_combiner_input = True
        elif isinstance(computation.preprocessor,
                        metric_types.FeaturePreprocessor):
            if features is None:
                features = {}
            # Later preprocessors overwrite earlier ones on key collisions.
            for v in computation.preprocessor.process(extracts):
                features.update(v)
            combiner_inputs.append(None)
            use_default_combiner_input = True
        else:
            combiner_inputs.append(
                next(computation.preprocessor.process(extracts)))

    output = {
        constants.SLICE_KEY_TYPES_KEY: extracts[constants.SLICE_KEY_TYPES_KEY],
        _COMBINER_INPUTS_KEY: combiner_inputs,
    }

    if use_default_combiner_input:
        if features is not None:
            # Copy before attaching features so the input is not mutated.
            extracts = copy.copy(extracts)
            extracts.update({constants.FEATURES_KEY: features})
        output[_DEFAULT_COMBINER_INPUT_KEY] = (
            metric_util.to_standard_metric_inputs(
                extracts, include_features=features is not None))

    yield output

    # Runs once the consumer resumes the generator; records whole seconds.
    self._timer.update(
        int((datetime.datetime.now() - start_time).total_seconds()))
def to_standard_metric_inputs_list(list_of_extracts):
    """Converts each extracts value into standard metric inputs.

    Args:
      list_of_extracts: Iterable of extracts to convert.

    Returns:
      A new list with one converted entry per element, in input order.
    """
    standard_inputs = []
    for extracts in list_of_extracts:
        # Second positional argument is presumably include_features - confirm
        # against metric_util.to_standard_metric_inputs.
        standard_inputs.append(
            metric_util.to_standard_metric_inputs(extracts, True))
    return standard_inputs