Beispiel #1
0
def load_and_deserialize_attributions(
    output_path: Text,
    output_file_format: Text = '',
    slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None
) -> Iterator[metrics_for_slice_pb2.AttributionsForSlice]:
    """Read and deserialize the AttributionsForSlice records.

    Args:
      output_path: Path or pattern to search for attribution files under. If a
        directory is passed, files matching 'attributions*' will be searched
        for.
      output_file_format: Optional file extension to filter files by. It is
        also forwarded to `_raw_value_iterator` as the format used for
        reading. If no files carry this extension, files matched without the
        extension are read instead (backwards compatibility with outputs that
        were written without an explicit suffix).
      slice_specs: A set of SingleSliceSpecs to use for filtering returned
        attributions. The attributions for a given slice key will be returned
        if that slice key matches any of the slice_specs.

    Yields:
      AttributionsForSlice protos found in matching files.
    """
    if tf.io.gfile.isdir(output_path):
        output_path = os.path.join(output_path, constants.ATTRIBUTIONS_KEY)
    # Pattern without any format suffix; reused for the fallback lookup below.
    base_pattern = _match_all_files(output_path)
    pattern = base_pattern
    if output_file_format:
        pattern = base_pattern + '.' + output_file_format
    paths = tf.io.gfile.glob(pattern)
    if not paths and output_file_format:
        # For backwards compatibility, check for files without an explicit
        # suffix, but still use output_file_format for parsing. Only needed
        # when a suffix was appended; otherwise this is the same glob again.
        paths = tf.io.gfile.glob(base_pattern)
    for value in _raw_value_iterator(paths, output_file_format):
        attributions = metrics_for_slice_pb2.AttributionsForSlice.FromString(
            value)
        if slice_specs and not slicer.slice_key_matches_slice_specs(
                slicer.deserialize_slice_key(attributions.slice_key),
                slice_specs):
            continue
        yield attributions
def load_and_deserialize_attributions(
    output_path: str,
    output_file_format: str = _TFRECORD_FORMAT,
    slice_specs: Optional[Iterable[slicer.SingleSliceSpec]] = None
) -> Iterator[metrics_for_slice_pb2.AttributionsForSlice]:
  """Read and deserialize the AttributionsForSlice records.

  Args:
    output_path: Path or pattern to search for attribution files under. If a
      directory is passed, files matching 'attributions*' will be searched for.
    output_file_format: Optional file extension to filter files by and the
      format to use for parsing. The default value is tfrecord. If no files
      carry this extension, files matched without the extension are read
      instead (still parsed using this format).
    slice_specs: A set of SingleSliceSpecs to use for filtering returned
      attributions. The attributions for a given slice key will be returned if
      that slice key matches any of the slice_specs.

  Yields:
    AttributionsForSlice protos found in matching files.
  """
  if tf.io.gfile.isdir(output_path):
    output_path = os.path.join(output_path, constants.ATTRIBUTIONS_KEY)
  # Pattern without any format suffix; reused for the fallback lookup below.
  base_pattern = _match_all_files(output_path)
  pattern = base_pattern
  if output_file_format:
    pattern = base_pattern + '.' + output_file_format
  paths = tf.io.gfile.glob(pattern)
  if not paths and output_file_format:
    # For backwards compatibility, check for files without an explicit suffix,
    # but still use the output_file_format for parsing. When no suffix was
    # appended above, this would just repeat the identical glob, so skip it.
    paths = tf.io.gfile.glob(base_pattern)
  for value in _raw_value_iterator(paths, output_file_format):
    attributions = metrics_for_slice_pb2.AttributionsForSlice.FromString(value)
    if slice_specs and not slicer.slice_key_matches_slice_specs(
        slicer.deserialize_slice_key(attributions.slice_key), slice_specs):
      continue
    yield attributions
Beispiel #3
0
 def testSliceKeyMatchesSliceSpecs(self, slice_key, slice_specs,
                                   expected_result):
     # Compare slicer.slice_key_matches_slice_specs on the parameterized
     # (slice_key, slice_specs) pair against the expected boolean outcome.
     actual_result = slicer.slice_key_matches_slice_specs(
         slice_key, slice_specs)
     self.assertEqual(expected_result, actual_result)