def test_decode_example_empty_feature(self):
        """Tests that features with empty value lists decode to empty arrays."""
        example_proto_text = """
    features {
      feature { key: "int_feature" value { int64_list { value: [ 0 ] } } }
      feature { key: "int_feature_empty" value { } }
      feature { key: "float_feature" value { float_list { value: [ 4.0 ] } } }
      feature { key: "float_feature_empty" value { } }
      feature { key: "str_feature" value { bytes_list { value: [ 'male' ] } } }
      feature { key: "str_feature_empty" value { } }
    }
    """
        # Empty-valued features are expected to decode to empty object-dtype
        # arrays. The builtin `object` is used for the dtype because the
        # `np.object` alias was deprecated in NumPy 1.20 and removed in 1.24.
        # NOTE(review): `np.integer`/`np.floating` are abstract scalar types;
        # passing them as explicit dtypes is deprecated in recent NumPy.
        # Consider switching to the concrete dtypes the decoder actually
        # emits (likely int64 / float32 for int64_list / float_list) — TODO
        # confirm against `_check_decoding_results` semantics.
        expected_decoded = {
            'int_feature': np.array([0], dtype=np.integer),
            'int_feature_empty': np.array([], dtype=object),
            'float_feature': np.array([4.0], dtype=np.floating),
            'float_feature_empty': np.array([], dtype=object),
            'str_feature': np.array([b'male'], dtype=object),
            'str_feature_empty': np.array([], dtype=object),
        }
        example = tf.train.Example()
        text_format.Merge(example_proto_text, example)

        decoder = tf_example_decoder.TFExampleDecoder()
        self._check_decoding_results(
            decoder.decode(example.SerializeToString()), expected_decoded)
def generate_statistics_from_tfrecord(
        data_location,
        output_path=None,
        stats_options=None,
        pipeline_options=None,
):
    """Compute data statistics from TFRecord files containing TFExamples.

    Runs a Beam pipeline to compute the data statistics and return the result
    data statistics proto.

    This is a convenience method for users with data in TFRecord format.
    Users with data in unsupported file/data formats, or users who wish
    to create their own Beam pipelines need to use the 'GenerateStatistics'
    PTransform API directly instead.

    Args:
      data_location: The location of the input data files.
      output_path: The file path to output data statistics result to. If None,
        we use a temporary directory. It will be a TFRecord file containing a
        single data statistics proto, and can be read with the
        'load_statistics' API.
      stats_options: Options for generating data statistics. If None, a
        default-constructed `options.StatsOptions()` is used.
      pipeline_options: Optional beam pipeline options. This allows users to
        specify various beam pipeline execution parameters like pipeline
        runner (DirectRunner or DataflowRunner), cloud dataflow service
        project id, etc. See
        https://cloud.google.com/dataflow/pipelines/specifying-exec-params
        for more details.

    Returns:
      A DatasetFeatureStatisticsList proto.
    """
    # Previously `options.StatsOptions()` was the default argument value,
    # which creates ONE options object at import time shared by every call
    # (mutable-default-argument pitfall). Construct a fresh one per call.
    if stats_options is None:
        stats_options = options.StatsOptions()
    if output_path is None:
        output_path = os.path.join(tempfile.mkdtemp(), 'data_stats.tfrecord')
    output_dir_path = os.path.dirname(output_path)
    if not tf.gfile.Exists(output_dir_path):
        tf.gfile.MakeDirs(output_dir_path)

    # PyLint doesn't understand Beam PTransforms.
    # pylint: disable=no-value-for-parameter
    with beam.Pipeline(options=pipeline_options) as p:
        # Auto detect tfrecord file compression format based on input data
        # path suffix.
        _ = (
            p
            |
            'ReadData' >> beam.io.ReadFromTFRecord(file_pattern=data_location)
            | 'DecodeData' >> beam.Map(
                tf_example_decoder.TFExampleDecoder().decode)
            |
            'GenerateStatistics' >> stats_api.GenerateStatistics(stats_options)
            | 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
                output_path,
                shard_name_template='',
                coder=beam.coders.ProtoCoder(
                    statistics_pb2.DatasetFeatureStatisticsList)))
    return load_statistics(output_path)
 def test_decode_example_none_ref_count(self):
     """Decoding must not lose or leak references to the None singleton.

     The assertion expects exactly one additional reference to None after
     decoding — presumably held by the still-live decoded result; verify
     against the decoder's handling of empty-valued features.
     """
     proto_text = '''
       features {
         feature {
           key: 'x'
           value { }
         }
       }
     '''
     example = text_format.Parse(proto_text, tf.train.Example())
     serialized = example.SerializeToString()
     refs_before = sys.getrefcount(None)
     result = tf_example_decoder.TFExampleDecoder().decode(serialized)
     refs_after = sys.getrefcount(None)
     self.assertEqual(refs_before + 1, refs_after)
 def test_decode_example(self, example_proto_text, decoded_example):
     """Parses a text-format Example, decodes it, and checks the result."""
     example = tf.train.Example()
     text_format.Merge(example_proto_text, example)
     actual = tf_example_decoder.TFExampleDecoder().decode(
         example.SerializeToString())
     self._check_decoding_results(actual, decoded_example)
 def test_decode_example_empty_input(self):
     """A serialized Example with no features decodes to an empty dict."""
     serialized = tf.train.Example().SerializeToString()
     actual = tf_example_decoder.TFExampleDecoder().decode(serialized)
     self._check_decoding_results(actual, {})