Example #1
0
    def test_infer_schema_with_infer_shape(self):
        """Checks schema inference with infer_feature_shape=True.

        feature1 has exactly one value in every example (min/max 1).
        feature2 has exactly three values in every example (min/max 3).
        The expected schema gives feature1 a fixed shape of size 1.
        It gives feature2 only a `value_count { min: 1 }` and no shape.
        Both features are expected with full presence (min_fraction 1.0).
        """
        stats_text = """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 0
                num_non_missing: 7
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
            }
          }
          features: {
            name: 'feature2'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 0
                num_non_missing: 7
                min_num_values: 3
                max_num_values: 3
              }
              unique: 5
            }
          }
        }
        """
        schema_text = """
        feature {
          name: "feature1"
          shape { dim { size: 1 } }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        feature {
          name: "feature2"
          value_count: { min: 1 }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        """

        statistics = text_format.Parse(
            stats_text, statistics_pb2.DatasetFeatureStatisticsList())
        expected_schema = text_format.Parse(schema_text, schema_pb2.Schema())

        # Run inference with shape inference switched on.
        actual_schema = validation_api.infer_schema(
            statistics, infer_feature_shape=True)
        self.assertEqual(actual_schema, expected_schema)
Example #2
0
    def test_infer_schema_without_string_domain(self):
        """Checks that no string domain is emitted past max_string_domain_size.

        feature1 has 3 distinct values ("a", "b", "c").
        Inference is run with max_string_domain_size=2.
        The expected schema therefore holds no domain for feature1.
        3 of the 7 examples lack feature1, so the expected presence has
        min_count 1 but no min_fraction.
        """
        stats_text = """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_missing: 3
                num_non_missing: 4
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
              rank_histogram {
                buckets {
                  low_rank: 0
                  high_rank: 0
                  label: "a"
                  sample_count: 2.0
                }
                buckets {
                  low_rank: 1
                  high_rank: 1
                  label: "b"
                  sample_count: 1.0
                }
                buckets {
                  low_rank: 2
                  high_rank: 2
                  label: "c"
                  sample_count: 1.0
                }
              }
            }
          }
        }
        """
        schema_text = """
        feature {
          name: "feature1"
          value_count: {
            min: 1
            max: 1
          }
          presence: {
            min_count: 1
          }
          type: BYTES
        }
        """

        statistics = text_format.Parse(
            stats_text, statistics_pb2.DatasetFeatureStatisticsList())
        expected_schema = text_format.Parse(schema_text, schema_pb2.Schema())

        # A domain limit of 2 is below feature1's 3 unique values.
        actual_schema = validation_api.infer_schema(
            statistics, max_string_domain_size=2)
        self.assertEqual(actual_schema, expected_schema)
 def test_infer_schema_invalid_multiple_datasets(self):
     """Checks that infer_schema rejects a statistics list with two datasets.

     Builds a DatasetFeatureStatisticsList holding two empty
     DatasetFeatureStatistics entries. Expects a ValueError whose message
     matches 'statistics proto with one dataset'.
     """
     statistics = statistics_pb2.DatasetFeatureStatisticsList()
     statistics.datasets.extend([
         statistics_pb2.DatasetFeatureStatistics(),
         statistics_pb2.DatasetFeatureStatistics()
     ])
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError,
                                 '.*statistics proto with one dataset.*'):
         _ = validation_api.infer_schema(statistics)
    def test_infer_schema(self):
        """Checks basic schema inference with infer_feature_shape=False.

        feature1 has exactly one value in each of the 7 examples.
        The expected schema describes it with `value_count { min: 1 max: 1 }`
        rather than a shape.
        The expected proto is passed through
        validation_api._may_be_set_legacy_flag before the comparison.
        """
        stats_text = """
        datasets {
          num_examples: 7
          features: {
            name: 'feature1'
            type: STRING
            string_stats: {
              common_stats: {
                num_non_missing: 7
                min_num_values: 1
                max_num_values: 1
              }
              unique: 3
            }
          }
        }
        """
        schema_text = """
        feature {
          name: "feature1"
          value_count: {
            min: 1
            max: 1
          }
          presence: {
            min_fraction: 1.0
            min_count: 1
          }
          type: BYTES
        }
        """

        statistics = text_format.Parse(
            stats_text, statistics_pb2.DatasetFeatureStatisticsList())
        expected_schema = text_format.Parse(schema_text, schema_pb2.Schema())
        # Apply the same legacy-flag adjustment the API applies, so the
        # comparison below is like-for-like (assumption; see helper).
        validation_api._may_be_set_legacy_flag(expected_schema)

        # Run inference with shape inference explicitly disabled.
        actual_schema = validation_api.infer_schema(
            statistics, infer_feature_shape=False)
        self.assertEqual(actual_schema, expected_schema)
 def test_infer_schema_invalid_statistics_input(self):
     """Checks that infer_schema rejects input that is not a statistics proto.

     Passes an empty dict and expects a TypeError whose message matches
     'should be a DatasetFeatureStatisticsList proto'.
     """
     # assertRaisesRegexp is a deprecated alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(
             TypeError,
             '.*should be a DatasetFeatureStatisticsList proto.*'):
         _ = validation_api.infer_schema({})