Ejemplo n.º 1
0
    def process(self, element: types.Extracts) -> List[types.Extracts]:
        """Attach slice keys to a shallow copy of the given extract.

        Args:
            element: Extract that must already hold the FPL entry produced by
                Predict().

        Returns:
            A single-element list containing a shallow copy of ``element``
            with the computed slice key types (and, when materialization is
            enabled, stringified slice keys) added.

        Raises:
            RuntimeError: If the FPL entry is missing from the extract.
            TypeError: If the FPL entry is not a FeaturesPredictionsLabels.
        """
        fpl = element.get(constants.FEATURES_PREDICTIONS_LABELS_KEY)
        if not fpl:
            raise RuntimeError(
                'FPL missing, Please ensure Predict() was called.')
        if not isinstance(fpl, types.FeaturesPredictionsLabels):
            raise TypeError(
                'Expected FPL to be instance of FeaturesPredictionsLabel. FPL was: '
                '%s of type %s' % (str(fpl), type(fpl)))
        computed_slices = list(
            slicer.get_slices_for_features_dict(fpl.features, self._slice_spec))

        # Shallow copy only: Beam disallows mutating the input element.
        output = copy.copy(element)
        output[constants.SLICE_KEY_TYPES_KEY] = computed_slices
        if self._materialize:
            # Emit human-readable slice keys for the materialized output table.
            output[constants.SLICE_KEYS_KEY] = types.MaterializedColumn(
                name=constants.SLICE_KEYS_KEY,
                value=[
                    slicer.stringify_slice_key(s).encode('utf-8')
                    for s in computed_slices
                ])
        return [output]
Ejemplo n.º 2
0
 def process(self, element: types.Extracts
            ) -> List[Tuple[SliceKeyType, types.Extracts]]:
   """Fan the extract out into one (slice_key, filtered_extract) pair per slice.

   The extract is filtered through the configured key filter once, and that
   single filtered dict is shared by every emitted pair. Fan-out metrics are
   updated with the number of pairs produced.
   """
   keep = self._key_filter_fn  # Hoisted: avoids attribute lookup per key.
   trimmed = {key: value for key, value in element.items() if keep(key)}
   pairs = [(slice_key, trimmed)
            for slice_key in element.get(constants.SLICE_KEY_TYPES_KEY)]
   self._num_slices_generated_per_instance.update(len(pairs))
   self._post_slice_num_instances.inc(len(pairs))
   return pairs
def get_fpl_copy(extracts: types.Extracts) -> types.FeaturesPredictionsLabels:
    """Return a copy of the FeaturesPredictionsLabels stored in the extracts.

    Only the features dict is shallow-copied (Beam forbids mutating input
    elements); labels, predictions and input_ref are shared with the
    original tuple.

    Raises:
        RuntimeError: If no FPL entry is present in the extracts.
    """
    original = extracts.get(constants.FEATURES_PREDICTIONS_LABELS_KEY)
    if not original:
        raise RuntimeError('FPL missing, Please ensure _Predict() was called.')
    return types.FeaturesPredictionsLabels(
        features=copy.copy(original.features),
        labels=original.labels,
        predictions=original.predictions,
        input_ref=original.input_ref)
Ejemplo n.º 4
0
  def process(self, element: types.Extracts
             ) -> Generator[Tuple[SliceKeyType, types.Extracts], None, None]:
    """Yield one (slice_key, filtered_extract) pair per slice key.

    Unless self._include_slice_keys_in_output is set, the slice-key entries
    themselves are stripped from the yielded extract. Fan-out counters are
    updated once the generator has been exhausted.
    """
    omit_slice_entries = not self._include_slice_keys_in_output
    slice_entry_keys = (constants.SLICE_KEY_TYPES_KEY, constants.SLICE_KEYS_KEY)
    trimmed = {
        key: value
        for key, value in element.items()
        if not (omit_slice_entries and key in slice_entry_keys)
    }
    emitted = 0
    for slice_key in element.get(constants.SLICE_KEY_TYPES_KEY):
      emitted += 1
      yield (slice_key, trimmed)

    self._num_slices_generated_per_instance.update(emitted)
    self._post_slice_num_instances.inc(emitted)
Ejemplo n.º 5
0
 def process(
         self, element: types.Extracts
 ) -> List[Tuple[SliceKeyType, types.Extracts]]:
     """Pair each distinct slice key with a filtered copy of the extract.

     The query based evaluator groups slices from multiple examples, so the
     keys are deduplicated to avoid overcounting. Batched keys arrive either
     as a VarLenTensorValue (variable-length rows) or as a 2D np.ndarray
     (fixed-length rows); both are flattened to a 1D sequence first.
     """
     keep = self._key_filter_fn  # Hoisted: avoids attribute lookup per key.
     trimmed = {key: value for key, value in element.items() if keep(key)}
     keys = element.get(constants.SLICE_KEY_TYPES_KEY)
     if isinstance(keys, types.VarLenTensorValue):
         keys = keys.values
     elif isinstance(keys, np.ndarray) and keys.ndim == 2:
         keys = keys.flatten()
     pairs = [(slice_key, trimmed) for slice_key in set(keys)]
     self._num_slices_generated_per_instance.update(len(pairs))
     self._post_slice_num_instances.inc(len(pairs))
     return pairs
Ejemplo n.º 6
0
 def process(
         self, element: types.Extracts
 ) -> List[Tuple[SliceKeyType, types.Extracts]]:
     """Pair each distinct slice key with a filtered copy of the extract.

     The query based evaluator groups slices into a multi-dimensional array
     with an extra dimension representing the examples matching the query
     key, so the keys are flattened and deduplicated before fan-out.
     """
     keep = self._key_filter_fn  # Hoisted: avoids attribute lookup per key.
     trimmed = {key: value for key, value in element.items() if keep(key)}
     keys = element.get(constants.SLICE_KEY_TYPES_KEY)
     if _is_multi_dim_keys(keys):
         arr = np.array(keys)
         deduped = set(arr.flatten())
         if not deduped and arr.shape:
             # flatten() drops the empty overall slice; restore it so the
             # overall slice is still emitted.
             deduped.add(())
         keys = deduped
     pairs = [(slice_key, trimmed) for slice_key in keys]
     self._num_slices_generated_per_instance.update(len(pairs))
     self._post_slice_num_instances.inc(len(pairs))
     return pairs