def test_sklearn_mi_with_slicing(self):
    """Checks the non-streaming sklearn MI generator on sliced input.

    Every record batch in ``self.record_batches`` is emitted once per slice
    key, and each slice is expected to yield whatever
    ``_get_test_stats_with_mi`` builds for the feature paths fa, fb and fd.
    """
    slice_keys = ['slice1', 'slice2']
    mi_feature_names = ('fa', 'fb', 'fd')

    # Slice-major order: all batches for 'slice1' first, then 'slice2'.
    sliced_record_batches = [
        (slice_key, record_batch)
        for slice_key in slice_keys
        for record_batch in self.record_batches
    ]

    # A fresh list of feature paths is built for every slice.
    expected_result = [
        (slice_key,
         _get_test_stats_with_mi(
             [types.FeaturePath([name]) for name in mi_feature_names]))
        for slice_key in slice_keys
    ]

    mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
        label_feature=types.FeaturePath(['label_key']),
        schema=self.schema,
        seed=TEST_SEED)
    generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
        mi_stats_fn,
        num_partitions=2,
        min_partitions_stat_presence=2,
        seed=TEST_SEED,
        max_examples_per_partition=1000,
        batch_size=1,
        name='NonStreaming Mutual Information')

    self.assertSlicingAwareTransformOutputEqual(
        sliced_record_batches, generator, expected_result)
def test_sklearn_mi(self):
    """Checks the non-streaming sklearn MI generator on un-sliced input.

    ``self.record_batches`` is fed through the generator and the output is
    compared with what ``_get_test_stats_with_mi`` builds for the feature
    paths fa, fb and fd. Both ``add_default_slice_key_*`` flags are passed
    as True to the assertion helper.
    """
    mi_feature_paths = [
        types.FeaturePath([name]) for name in ('fa', 'fb', 'fd')
    ]
    expected_result = [_get_test_stats_with_mi(mi_feature_paths)]

    mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
        label_feature=types.FeaturePath(['label_key']),
        schema=self.schema,
        seed=TEST_SEED)
    generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
        mi_stats_fn,
        num_partitions=2,
        min_partitions_stat_presence=2,
        seed=TEST_SEED,
        max_examples_per_partition=1000,
        batch_size=1,
        name='NonStreaming Mutual Information')

    self.assertSlicingAwareTransformOutputEqual(
        self.record_batches,
        generator,
        expected_result,
        add_default_slice_key_to_input=True,
        add_default_slice_key_to_output=True)
def test_sklearn_mi(self):
    """Runs the non-streaming sklearn MI generator over mixed-type batches."""
    # Integration tests involving Beam and AMI are challenging to write
    # because Beam PCollections are unordered while the results of adjusted MI
    # depend on the order of the data for small datasets. This test case tests
    # MI with one label which will give a value of 0 regardless of
    # the ordering of elements in the PCollection. The purpose of this test is
    # to ensure that the Mutual Information pipeline is able to handle a
    # variety of input types. Unit tests ensuring correctness of the MI value
    # itself are included in sklearn_mutual_information_test.

    # fa is categorical, fb is numeric, fc is multivalent and fd has null values
    #
    # The 'fc' columns are ragged (rows of different lengths, plus a None), so
    # they are built with an explicit dtype=object: NumPy >= 1.24 refuses to
    # create ragged arrays implicitly. This yields the same 1-D object array
    # that older NumPy versions produced without the dtype.
    batches = [{
        'fa': np.array([
            np.array(['Red']),
            np.array(['Green']),
            np.array(['Blue']),
            np.array(['Green']),
        ]),
        'fb': np.array([
            np.array([1.0]),
            np.array([2.2]),
            np.array([3.3]),
            np.array([1.3]),
        ]),
        'fc': np.array([
            np.array([1, 3, 1]),
            np.array([2, 6]),
            np.array([4, 6]),
            None,
        ], dtype=object),
        'fd': np.array([
            np.array([0.4]),
            np.array([0.4]),
            np.array([0.3]),
            np.array([0.2]),
        ]),
        'label_key': np.array([
            np.array(['Label']),
            np.array(['Label']),
            np.array(['Label']),
            np.array(['Label']),
        ]),
    }, {
        'fa': np.array([
            np.array(['Red']),
            np.array(['Blue']),
            np.array(['Blue']),
            np.array(['Green']),
            np.array(['Green']),
        ]),
        'fb': np.array([
            np.array([1.2]),
            np.array([0.5]),
            np.array([1.3]),
            np.array([2.3]),
            np.array([0.3]),
        ]),
        'fc': np.array([
            np.array([1]),
            np.array([3, 2]),
            np.array([1, 4]),
            np.array([0]),
            np.array([3]),
        ], dtype=object),
        'fd': np.array([
            np.array([0.3]),
            np.array([0.4]),
            np.array([1.7]),
            # np.nan instead of np.NaN: the capitalised alias was removed in
            # NumPy 2.0.
            np.array([np.nan]),
            np.array([4.4]),
        ]),
        'label_key': np.array([
            np.array(['Label']),
            np.array(['Label']),
            np.array(['Label']),
            np.array(['Label']),
            np.array(['Label']),
        ]),
    }]

    # The text-format literals below are kept exactly as they appear in the
    # file, including their internal whitespace and line breaks.
    schema = text_format.Parse(
        """ feature { name: "fa" type: BYTES shape { dim { size: 1 } } } feature { name: "fb" type: FLOAT shape { dim { size: 1 } } } feature { name: "fc" type: INT value_count: { min: 0 max: 2 } } feature { name: "fd" type: FLOAT shape { dim { size: 1 } } } 
feature { name: "label_key" type: BYTES shape { dim { size: 1 } } }""", schema_pb2.Schema())

    # fc does not appear in the expected output; only fa, fb and fd do. Every
    # MI statistic is 0.0 and both partitions report a value (num_partitions
    # is 2.0).
    expected_result = [
        text_format.Parse(
            """ features { name: "fa" custom_stats { name: "max_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "max_sklearn_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_mutual_information" num: 0.0 } custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 } custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 } custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 } } features { name: "fb" custom_stats { name: "max_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "max_sklearn_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_mutual_information" num: 0.0 } custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 } custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 } custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 } } features { name: "fd" custom_stats { name: 
"max_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "max_sklearn_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "median_sklearn_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "min_sklearn_mutual_information" num: 0.0 } custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 } custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 } custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 } }""", statistics_pb2.DatasetFeatureStatistics())
    ]

    generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
        sklearn_mutual_information.SkLearnMutualInformation(
            label_feature='label_key', schema=schema, seed=TEST_SEED),
        num_partitions=2,
        min_partitions_stat_presence=2,
        seed=TEST_SEED,
        max_examples_per_partition=1000,
        name='NonStreaming Mutual Information')
    self.assertTransformOutputEqual(batches, generator, expected_result)