コード例 #1
def load_sharded_statistics(
    input_path_prefix: Optional[str] = None,
    input_paths: Optional[Iterable[str]] = None,
    io_provider: Optional[statistics_io_impl.StatisticsIOProvider] = None
) -> DatasetListView:
    """Read a sharded DatasetFeatureStatisticsList from disk as a DatasetListView.

  Args:
    input_path_prefix: If passed, loads files starting with this prefix and
      ending with a pattern corresponding to the output of the provided
      io_provider.
    input_paths: A list of file paths of files containing sharded
      DatasetFeatureStatisticsList protos.
    io_provider: Optional StatisticsIOProvider. If unset, a default will be
      constructed.

  Returns:
    A DatasetListView containing the merged proto.

  Raises:
    ValueError: If neither or both of input_path_prefix and input_paths are
      provided.
  """
    # Exactly one of the two input sources must be set. The parentheses are
    # essential: without them Python chains the comparison as
    # (prefix is None) and (None == input_paths) and (input_paths is None),
    # which never fires when both arguments are provided.
    if (input_path_prefix is None) == (input_paths is None):
        raise ValueError(
            'Must provide one of input_path_prefix, input_paths.')
    if io_provider is None:
        io_provider = statistics_io_impl.get_io_provider()
    if input_path_prefix is not None:
        input_paths = io_provider.glob(input_path_prefix)
    # Merge every dataset from every shard into a single accumulator, then
    # parse the merged result back into a single proto.
    acc = statistics.DatasetListAccumulator()
    stats_iter = io_provider.record_iterator_impl(input_paths)
    for stats_list in stats_iter:
        for dataset in stats_list.datasets:
            acc.MergeDatasetFeatureStatistics(dataset.SerializeToString())
    stats = statistics_pb2.DatasetFeatureStatisticsList()
    stats.ParseFromString(acc.Get())
    return DatasetListView(stats)
コード例 #2
 def test_load_sharded_pattern(self):
     """Writes stats as single-dataset TFRecord shards and reloads the merge."""
     full_stats_proto = statistics_pb2.DatasetFeatureStatisticsList()
     text_format.Parse(_STATS_PROTO, full_stats_proto)
     tmp_dir = self.create_tempdir()
     tmp_path = os.path.join(tmp_dir, 'statistics-0-of-1')
     writer = tf.compat.v1.io.TFRecordWriter(tmp_path)
     # Write each dataset as its own single-dataset shard.
     for dataset in full_stats_proto.datasets:
         shard = statistics_pb2.DatasetFeatureStatisticsList()
         shard.datasets.append(dataset)
         writer.write(shard.SerializeToString())
     writer.close()
     # NOTE: the original used tmp_path.rstrip('-0-of-1'), but rstrip strips a
     # *character set*, not a suffix, and only worked here by accident. Slice
     # off the literal shard suffix instead.
     view = stats_util.load_sharded_statistics(
         input_path_prefix=tmp_path[:-len('-0-of-1')],
         io_provider=statistics_io_impl.get_io_provider('tfrecords'))
     compare.assertProtoEqual(self, view.proto(), full_stats_proto)
コード例 #3
    def test_write_and_read_records(self):
        """Statistics written via the record sink are readable via the iterator."""
        expected = []
        for name in ('d1', 'd2'):
            expected.append(
                statistics_pb2.DatasetFeatureStatisticsList(datasets=[
                    statistics_pb2.DatasetFeatureStatistics(name=name)
                ]))
        output_prefix = tempfile.mkdtemp() + '/statistics'

        # Run a pipeline that writes the statistics as sharded records.
        provider = statistics_io_impl.get_io_provider('tfrecords')
        with beam.Pipeline() as pipeline:
            _ = (pipeline
                 | beam.Create(expected)
                 | provider.record_sink_impl(output_prefix))

        # Read everything back via the provider's glob + iterator and compare
        # ignoring order.
        read_back = provider.record_iterator_impl(provider.glob(output_prefix))
        self.assertCountEqual(expected, read_back)
コード例 #4
def load_stats_tfrecord(
        input_path: Text) -> statistics_pb2.DatasetFeatureStatisticsList:
    """Loads data statistics proto from TFRecord file.

  Args:
    input_path: Data statistics file path.

  Returns:
    A DatasetFeatureStatisticsList proto.

  Raises:
    ValueError: If the file contains more than one record.
  """
    it = statistics_io_impl.get_io_provider('tfrecords').record_iterator_impl(
        [input_path])
    result = next(it)
    # The file must hold exactly one record, so a second next() must exhaust
    # the iterator. The original raised the ValueError *inside* the try and
    # routed it through a redundant `except Exception as e: raise e` handler;
    # raising after the try is equivalent and far clearer.
    try:
        next(it)
    except StopIteration:
        return result
    raise ValueError('load_stats_tfrecord expects a single record.')
コード例 #5
    def __init__(
        self,
        binary_proto_path: str,
        records_path_prefix: str,
        io_provider: Optional[statistics_io_impl.StatisticsIOProvider] = None
    ) -> None:
        """Initializes the transform.

    Args:
      binary_proto_path: Output path for writing statistics as a binary proto.
      records_path_prefix: File pattern for writing statistics to sharded
        records.
      io_provider: Optional StatisticsIOProvider. If unset, a default will be
        constructed. This argument determines the format of statistics output.
    """
        self._binary_proto_path = binary_proto_path
        self._records_path_prefix = records_path_prefix
        # Fall back to the default provider when none was supplied.
        self._io_provider = (
            statistics_io_impl.get_io_provider()
            if io_provider is None else io_provider)
コード例 #6
    def __init__(
        self,
        binary_proto_path: str,
        records_path_prefix_no_suffix: str,
    ) -> None:
        """Initialize WriteStatisticsBinaryAndMaybeRecords.

    Args:
      binary_proto_path: Output path for writing statistics as a binary proto.
      records_path_prefix_no_suffix: File pattern for writing statistics to
        sharded records. An appropriate file type suffix (e.g., .tfrecords) and
        shard numbers will be added.
    """
        if not statistics_io_impl.should_write_sharded():
            # Sharded record output disabled: write only the binary proto.
            self._output_transform = WriteStatisticsToBinaryFile(
                binary_proto_path)
            return
        provider = statistics_io_impl.get_io_provider()
        # Append the provider-specific file suffix (e.g. '.tfrecords') to the
        # caller-supplied prefix before handing off to the combined writer.
        full_prefix = records_path_prefix_no_suffix + provider.file_suffix()
        self._output_transform = WriteStatisticsToRecordsAndBinaryFile(
            binary_proto_path, full_prefix, provider)