Example #1
0
def get_generators(
        options: stats_options.StatsOptions,
        in_memory: bool = False) -> List[stats_generator.StatsGenerator]:
    """Builds the list of stats generators requested by `options`.

    The list always starts with a NumExamplesStatsGenerator. Default, custom,
    semantic-domain and schema-dependent generators are appended according to
    `options`. Finally, every CombinerFeatureStatsGenerator in the list is
    folded into one CombinerFeatureStatsWrapperGenerator placed at the end.

    Args:
      options: A StatsOptions object.
      in_memory: Whether the generators will be used to generate statistics in
        memory (True) or using Beam (False).

    Returns:
      A list of stats generator objects.

    Raises:
      TypeError: If `in_memory` is True and one of the resulting generators
        does not extend CombinerStatsGenerator.
    """
    result = [NumExamplesStatsGenerator(options.weight_feature)]

    if options.add_default_generators:
        result += _get_default_generators(options, in_memory)

    # User-supplied (custom) stats generators.
    if options.generators:
        result += options.generators

    if options.enable_semantic_domain_stats:
        # The semantic domain generators get a combiner wrapper of their own,
        # so that sampling is applied to them only and the remaining feature
        # stats generators are left unaffected.
        semantic_wrapper = CombinerFeatureStatsWrapperGenerator(
            [
                image_stats_generator.ImageStatsGenerator(),
                natural_language_domain_inferring_stats_generator.
                NLDomainInferringStatsGenerator(),
                time_stats_generator.TimeStatsGenerator(),
            ],
            sample_rate=options.semantic_domain_stats_sample_rate)
        result.append(semantic_wrapper)

    if options.schema is not None:
        if _schema_has_sparse_features(options.schema):
            sparse_generator = (
                sparse_feature_stats_generator.SparseFeatureStatsGenerator(
                    options.schema))
            result.append(sparse_generator)
        if _schema_has_natural_language_domains(options.schema):
            nl_generator = natural_language_stats_generator.NLStatsGenerator(
                options.schema, options.vocab_paths,
                options.num_histogram_buckets,
                options.num_quantiles_histogram_buckets,
                options.num_rank_histogram_buckets)
            result.append(nl_generator)
        if options.schema.weighted_feature:
            weighted_generator = (
                weighted_feature_stats_generator.WeightedFeatureStatsGenerator(
                    options.schema))
            result.append(weighted_generator)
        if options.label_feature and not in_memory:
            # The LiftStatsGenerator is not a CombinerStatsGenerator and
            # therefore cannot currently be used for in_memory executions.
            lift_generator = lift_stats_generator.LiftStatsGenerator(
                y_path=types.FeaturePath([options.label_feature]),
                schema=options.schema,
                example_weight_map=options.example_weight_map,
                output_custom_stats=True)
            result.append(lift_generator)

    # Split off every CombinerFeatureStatsGenerator so they can be replaced by
    # a single CombinerFeatureStatsWrapperGenerator.
    feature_combiners = []
    remaining = []
    for candidate in result:
        if isinstance(candidate,
                      stats_generator.CombinerFeatureStatsGenerator):
            feature_combiners.append(candidate)
        else:
            remaining.append(candidate)
    if feature_combiners:
        result = remaining + [
            CombinerFeatureStatsWrapperGenerator(feature_combiners)
        ]

    if not in_memory:
        return result

    # In-memory execution only supports combiner-based generators.
    for candidate in result:
        if isinstance(candidate, stats_generator.CombinerStatsGenerator):
            continue
        raise TypeError(
            'Statistics generator used in generate_statistics_in_memory '
            'must extend CombinerStatsGenerator, found object of type %s.' %
            candidate.__class__.__name__)
    return result
Example #2
0
  def test_image_stats_generator_check_is_image_ratio(self):
    """Checks is_image_ratio_threshold on a feature that is partially images."""
    encode = FakeImageDecoder.encode_image_metadata
    # One of the six encoded values carries an empty format, so the image ratio
    # is about 0.83.
    batches = [
        [
            np.array([encode('PNG', 2, 4), encode('JPEG', 4, 2)]),
            np.array([
                encode('TIFF', 5, 1),
                encode('', -1, -1),
                encode('TIFF', 3, 7),
            ]),
        ],
        [np.array([encode('GIF', 2, 1)])],
    ]
    decoder = FakeImageDecoder()

    def build_generator(ratio_threshold):
      # Both scenarios share the decoder and differ only in the threshold.
      return image_stats_generator.ImageStatsGenerator(
          image_decoder=decoder,
          is_image_ratio_threshold=ratio_threshold,
          examples_threshold=1,
          enable_size_stats=True)

    # With is_image_ratio_threshold=0.85 we do not expect stats.
    no_stats = statistics_pb2.FeatureNameStatistics()
    self.assertCombinerOutputEqual(batches, build_generator(0.85), no_stats)

    # With is_image_ratio_threshold=0.8 we expect stats.
    with_stats = text_format.Parse(
        """
            custom_stats {
              name: 'domain_info'
              str: 'image_domain {}'
            }
            custom_stats {
              name: 'image_format_histogram'
              rank_histogram {
                buckets {
                  label: 'UNKNOWN'
                  sample_count: 1
                }
                buckets {
                  label: 'GIF'
                  sample_count: 1
                }
                buckets {
                  label: 'JPEG'
                  sample_count: 1
                }
                buckets {
                  label: 'PNG'
                  sample_count: 1
                }
                buckets {
                  label: 'TIFF'
                  sample_count: 2
                }
              }
            }
            custom_stats {
              name: 'image_max_width'
              num: 7.0
            }
            custom_stats {
              name: 'image_max_height'
              num: 5.0
            }
            """, statistics_pb2.FeatureNameStatistics())
    self.assertCombinerOutputEqual(batches, build_generator(0.8), with_stats)
Example #3
0
  def test_image_stats_generator_examples_threshold_check(self):
    """Checks examples_threshold on a feature whose values are all images."""
    encode = FakeImageDecoder.encode_image_metadata
    # Six encoded image values spread over two batches.
    first_batch = [
        np.array([encode('PNG', 2, 4), encode('JPEG', 4, 2)]),
        np.array([
            encode('TIFF', 5, 1),
            encode('JPEG', 1, 1),
            encode('TIFF', 3, 7),
        ]),
    ]
    second_batch = [np.array([encode('GIF', 2, 1)])]
    batches = [first_batch, second_batch]
    decoder = FakeImageDecoder()

    def build_generator(min_examples):
      # Both scenarios share the decoder and differ only in the threshold.
      return image_stats_generator.ImageStatsGenerator(
          image_decoder=decoder,
          examples_threshold=min_examples,
          enable_size_stats=True)

    # With examples_threshold = 7 statistics should not be generated.
    no_stats = statistics_pb2.FeatureNameStatistics()
    self.assertCombinerOutputEqual(batches, build_generator(7), no_stats)

    # With examples_threshold = 6 statistics should be generated.
    with_stats = text_format.Parse(
        """
            custom_stats {
              name: 'domain_info'
              str: 'image_domain {}'
            }
            custom_stats {
              name: 'image_format_histogram'
              rank_histogram {
                buckets {
                  label: 'GIF'
                  sample_count: 1
                }
                buckets {
                  label: 'JPEG'
                  sample_count: 2
                }
                buckets {
                  label: 'PNG'
                  sample_count: 1
                }
                buckets {
                  label: 'TIFF'
                  sample_count: 2
                }
              }
            }
            custom_stats {
              name: 'image_max_width'
              num: 7.0
            }
            custom_stats {
              name: 'image_max_height'
              num: 5.0
            }
            """, statistics_pb2.FeatureNameStatistics())
    self.assertCombinerOutputEqual(batches, build_generator(6), with_stats)
Example #4
0
 def test_image_stats_generator_real_image(self):
     """Runs the generator on files read from the 'testdata' directory."""
     data_dir = os.path.join(os.path.dirname(__file__), 'testdata')

     def load(filename):
         # Reads one test file located next to this module.
         return _read_file(os.path.join(data_dir, filename))

     # First column: two examples mixing image files and non-image content.
     first_column = pa.Column.from_array(
         'feature',
         pa.array([
             [
                 load('image1.gif'),
                 load('image2.png'),
                 load('not_a_image.abc'),
             ],
             [load('image3.bmp'), b'not_a_image'],
         ]))
     # Second column: a single example holding one file.
     second_column = pa.Column.from_array(
         'feature', pa.array([[load('image4.png')]]))

     expected = text_format.Parse(
         """
         custom_stats {
           name: 'domain_info'
           str: 'image_domain {}'
         }
         custom_stats {
           name: 'image_format_histogram'
           rank_histogram {
             buckets {
               label: 'UNKNOWN'
               sample_count: 2
             }
             buckets {
               label: 'bmp'
               sample_count: 1
             }
             buckets {
               label: 'gif'
               sample_count: 1
             }
             buckets {
               label: 'png'
               sample_count: 2
             }
           }
         }
         custom_stats {
           name: 'image_max_width'
           num: 51.0
         }
         custom_stats {
           name: 'image_max_height'
           num: 26.0
         }
         """, statistics_pb2.FeatureNameStatistics())
     # No image_decoder is passed here, unlike the tests that supply a
     # FakeImageDecoder.
     generator = image_stats_generator.ImageStatsGenerator(
         is_image_ratio_threshold=0.6,
         values_threshold=1,
         enable_size_stats=True)
     self.assertCombinerOutputEqual(
         [first_column, second_column], generator, expected)
Example #5
0
 def test_image_stats_generator_disable_size_stats(self):
     """Checks that enable_size_stats=False leaves out the image size stats."""
     encode = FakeImageDecoder.encode_image_metadata
     # Same encoded values as test_image_stats_generator_check_is_image_ratio.
     first_column = pa.Column.from_array(
         'feature',
         pa.array([
             [encode('PNG', 2, 4), encode('JPEG', 4, 2)],
             [
                 encode('TIFF', 5, 1),
                 encode('', -1, -1),
                 encode('TIFF', 3, 7),
             ],
         ]))
     second_column = pa.Column.from_array(
         'feature', pa.array([[encode('GIF', 2, 1)]]))

     generator = image_stats_generator.ImageStatsGenerator(
         image_decoder=FakeImageDecoder(),
         is_image_ratio_threshold=0.8,
         values_threshold=1,
         enable_size_stats=False)

     # Only the domain info and the format histogram are expected; there are
     # no image_max_width / image_max_height entries.
     expected = text_format.Parse(
         """
         custom_stats {
           name: 'domain_info'
           str: 'image_domain {}'
         }
         custom_stats {
           name: 'image_format_histogram'
           rank_histogram {
             buckets {
               label: 'UNKNOWN'
               sample_count: 1
             }
             buckets {
               label: 'GIF'
               sample_count: 1
             }
             buckets {
               label: 'JPEG'
               sample_count: 1
             }
             buckets {
               label: 'PNG'
               sample_count: 1
             }
             buckets {
               label: 'TIFF'
               sample_count: 2
             }
           }
         }
         """, statistics_pb2.FeatureNameStatistics())
     self.assertCombinerOutputEqual(
         [first_column, second_column], generator, expected)
 def test_image_stats_generator_single_feature_all_images(self):
     """Checks the stats of a single feature 'a' whose values are all images."""
     # input with two batches: first batch has two examples and second batch
     # has a single example.
     # The two examples of the first batch hold 2 and 3 values, so the outer
     # array is ragged. It is built with dtype=object explicitly: NumPy
     # deprecated implicit ragged arrays in 1.19 and raises ValueError for
     # them from 1.24 on. The resulting object array of two inner arrays is
     # the same one older NumPy versions produced implicitly.
     batches = [{
         'a':
         np.array([
             np.array([
                 FakeImageDecoder.encode_image_metadata('PNG', 2, 4),
                 FakeImageDecoder.encode_image_metadata('JPEG', 4, 2)
             ]),
             np.array([
                 FakeImageDecoder.encode_image_metadata('TIFF', 5, 1),
                 FakeImageDecoder.encode_image_metadata('JPEG', 1, 1),
                 FakeImageDecoder.encode_image_metadata('TIFF', 3, 7)
             ])
         ], dtype=object)
     }, {
         'a':
         np.array([
             np.array([FakeImageDecoder.encode_image_metadata('GIF', 2, 1)])
         ])
     }]
     # Expected output is keyed by feature name, like the input batches.
     expected_result = {
         'a':
         text_format.Parse(
             """
         name: 'a'
         custom_stats {
           name: 'is_image'
           num: 1.0
         }
         custom_stats {
           name: 'max_image_width'
           num: 7.0
         }
         custom_stats {
           name: 'max_image_height'
           num: 5.0
         }
         custom_stats {
           name: 'image_format_histogram'
           rank_histogram {
             buckets {
               label: 'GIF'
               sample_count: 1
             }
             buckets {
               label: 'JPEG'
               sample_count: 2
             }
             buckets {
               label: 'PNG'
               sample_count: 1
             }
             buckets {
               label: 'TIFF'
               sample_count: 2
             }
           }
         }
         """, statistics_pb2.FeatureNameStatistics())
     }
     image_decoder = FakeImageDecoder()
     generator = image_stats_generator.ImageStatsGenerator(
         image_decoder=image_decoder)
     self.assertCombinerOutputEqual(batches, generator, expected_result)