def test_sklearn_mi(self):
  """Runs the non-streaming MI generator and checks stats for fa, fb and fd.

  Inputs and expected output are wrapped in the default slice key by the
  slicing-aware assertion helper.
  """
  mi_feature_paths = [
      types.FeaturePath([feature_name]) for feature_name in ('fa', 'fb', 'fd')
  ]
  expected_stats = [_get_test_stats_with_mi(mi_feature_paths)]

  mi_computer = sklearn_mutual_information.SkLearnMutualInformation(
      label_feature=types.FeaturePath(['label_key']),
      schema=self.schema,
      seed=TEST_SEED)
  stats_generator = (
      partitioned_stats_generator.NonStreamingCustomStatsGenerator(
          mi_computer,
          num_partitions=2,
          min_partitions_stat_presence=2,
          seed=TEST_SEED,
          max_examples_per_partition=1000,
          batch_size=1,
          name='NonStreaming Mutual Information'))

  self.assertSlicingAwareTransformOutputEqual(
      self.record_batches,
      stats_generator,
      expected_stats,
      add_default_slice_key_to_input=True,
      add_default_slice_key_to_output=True)
def test_sklearn_mi_with_slicing(self):
  """Checks MI stats for fa, fb and fd are emitted separately per slice."""
  slice_keys = ['slice1', 'slice2']

  # Every record batch is fed once under each slice key, slice-major order.
  sliced_record_batches = [
      (slice_key, record_batch)
      for slice_key in slice_keys
      for record_batch in self.record_batches
  ]

  # Each slice is expected to produce the same MI statistics.
  expected_result = [
      (slice_key,
       _get_test_stats_with_mi([
           types.FeaturePath([feature_name])
           for feature_name in ('fa', 'fb', 'fd')
       ]))
      for slice_key in slice_keys
  ]

  mi_computer = sklearn_mutual_information.SkLearnMutualInformation(
      label_feature=types.FeaturePath(['label_key']),
      schema=self.schema,
      seed=TEST_SEED)
  stats_generator = (
      partitioned_stats_generator.NonStreamingCustomStatsGenerator(
          mi_computer,
          num_partitions=2,
          min_partitions_stat_presence=2,
          seed=TEST_SEED,
          max_examples_per_partition=1000,
          batch_size=1,
          name='NonStreaming Mutual Information'))

  self.assertSlicingAwareTransformOutputEqual(
      sliced_record_batches, stats_generator, expected_result)
def test_mi_with_missing_label_key(self):
  """A label feature that is not in the schema must raise ValueError."""
  batch = pa.RecordBatch.from_arrays(
      [pa.array([[1]]), pa.array([[1]])], ["label", "fa"])
  # The schema only knows a feature called "label"; the MI computer below is
  # asked for "label_key", which is absent.
  schema = text_format.Parse(
      """
      feature {
        name: "fa"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "label"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "Feature label_key not found in the schema."):
    sklearn_mutual_information.SkLearnMutualInformation(
        types.FeaturePath(["label_key"]), schema, TEST_SEED).compute(batch)
def test_mi_with_multivalent_label(self):
  """A label column holding more than one value per row must raise."""
  # The single label row has two values: [1, 2].
  batch = pa.RecordBatch.from_arrays(
      [pa.array([[1, 2]]), pa.array([[1]])], ["label_key", "fa"])
  schema = text_format.Parse(
      """
      feature {
        name: "fa"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "label_key"
        type: FLOAT
        value_count: {
          min: 1
          max: 2
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "Label column contains unsupported data."):
    sklearn_mutual_information.SkLearnMutualInformation(
        types.FeaturePath(["label_key"]), schema, TEST_SEED).compute(batch)
def test_mi_with_invalid_features(self):
  """With only a multivalent non-label feature, compute must raise.

  The expected message appears to come from sklearn's input validation —
  presumably the multivalent feature is skipped, leaving no usable data
  (confirm against SkLearnMutualInformation).
  """
  # Build a RecordBatch, like the sibling tests do; a pa.Table here looks like
  # a leftover from the Table -> RecordBatch migration.
  batch = pa.RecordBatch.from_arrays(
      [pa.array([[1]]), pa.array([[1, 2]])],
      ["label_key", "multivalent_feature"])
  schema = text_format.Parse(
      """
      feature {
        name: "label_key"
        type: INT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "multivalent_feature"
        type: INT
        value_count: {
          min: 2
          max: 2
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, "Found array with 0 sample"):
    sklearn_mutual_information.SkLearnMutualInformation(
        types.FeaturePath(["label_key"]), schema, TEST_SEED).compute(batch)
def test_mi_with_multivalent_label(self):
  """A label column holding more than one value per row must raise."""
  # The first label row has two values: [1.0, 2.0].
  batch = {
      "fa": [np.array([1.0]), np.array([2.0])],
      "label_key": [np.array([1.0, 2.0]), np.array([2.0])]
  }
  schema = text_format.Parse(
      """
      feature {
        name: "fa"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "label_key"
        type: FLOAT
        value_count: {
          min: 1
          max: 2
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "Label column contains unsupported data."):
    sklearn_mutual_information.SkLearnMutualInformation(
        "label_key", schema, TEST_SEED).compute(batch)
def test_mi_with_missing_label_key(self):
  """A label feature that is not in the schema must raise ValueError."""
  batch = {
      "fa": [np.array([1.0]), np.array([2.0])],
      "label": [np.array([1.0]), np.array([2.0])]
  }
  # The schema only knows a feature called "label"; the MI computer below is
  # asked for "label_key", which is absent.
  schema = text_format.Parse(
      """
      feature {
        name: "fa"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "label"
        type: FLOAT
        shape {
          dim {
            size: 1
          }
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              "Feature label_key not found in the schema."):
    sklearn_mutual_information.SkLearnMutualInformation(
        "label_key", schema, TEST_SEED).compute(batch)
def test_mi_with_invalid_features(self):
  """With only a multivalent non-label feature, compute must raise.

  The expected message appears to come from sklearn's input validation —
  presumably the multivalent feature is skipped, leaving no usable data
  (confirm against SkLearnMutualInformation).
  """
  batch = {
      "label_key": np.array([np.array([1])]),
      "multivalent_feature": np.array([np.array([1, 2])])
  }
  schema = text_format.Parse(
      """
      feature {
        name: "label_key"
        type: INT
        shape {
          dim {
            size: 1
          }
        }
      }
      feature {
        name: "multivalent_feature"
        type: INT
        value_count: {
          min: 2
          max: 2
        }
      }
      """, schema_pb2.Schema())
  # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, "Found array with 0 sample"):
    sklearn_mutual_information.SkLearnMutualInformation(
        "label_key", schema, TEST_SEED).compute(batch)
def _assert_mi_output_equal(self, batch, expected, schema, label_feature):
  """Computes MI over `batch` and asserts it matches `expected`.

  Args:
    batch: Input batch handed directly to `compute`.
    expected: Expected statistics proto.
    schema: Schema used to build the MI computer.
    label_feature: Label feature given to the MI computer.
  """
  mi_computer = sklearn_mutual_information.SkLearnMutualInformation(
      label_feature, schema, TEST_SEED)
  actual_stats = mi_computer.compute(batch)
  # normalize_numbers lets numerically-equal values compare equal.
  compare.assertProtoEqual(
      self, actual_stats, expected, normalize_numbers=True)
def test_sklearn_mi(self):
  # Integration tests involving Beam and AMI are challenging to write
  # because Beam PCollections are unordered while the results of adjusted MI
  # depend on the order of the data for small datasets. This test case tests
  # MI with one label which will give a value of 0 regardless of
  # the ordering of elements in the PCollection. The purpose of this test is
  # to ensure that the Mutual Information pipeline is able to handle a
  # variety of input types. Unit tests ensuring correctness of the MI value
  # itself are included in sklearn_mutual_information_test.

  # fa is categorical, fb is numeric, fc is multivalent and fd has null values.
  # The fc rows have differing lengths (and one row is None), so those columns
  # are built with dtype=object: NumPy >= 1.24 refuses to infer ragged arrays,
  # and older NumPy produced exactly this object array.
  batches = [{
      'fa': np.array([
          np.array(['Red']),
          np.array(['Green']),
          np.array(['Blue']),
          np.array(['Green'])
      ]),
      'fb': np.array([
          np.array([1.0]),
          np.array([2.2]),
          np.array([3.3]),
          np.array([1.3])
      ]),
      'fc': np.array([
          np.array([1, 3, 1]),
          np.array([2, 6]),
          np.array([4, 6]),
          None
      ], dtype=object),
      'fd': np.array([
          np.array([0.4]),
          np.array([0.4]),
          np.array([0.3]),
          np.array([0.2])
      ]),
      'label_key': np.array([
          np.array(['Label']),
          np.array(['Label']),
          np.array(['Label']),
          np.array(['Label'])
      ])
  }, {
      'fa': np.array([
          np.array(['Red']),
          np.array(['Blue']),
          np.array(['Blue']),
          np.array(['Green']),
          np.array(['Green'])
      ]),
      'fb': np.array([
          np.array([1.2]),
          np.array([0.5]),
          np.array([1.3]),
          np.array([2.3]),
          np.array([0.3])
      ]),
      'fc': np.array([
          np.array([1]),
          np.array([3, 2]),
          np.array([1, 4]),
          np.array([0]),
          np.array([3])
      ], dtype=object),
      'fd': np.array([
          np.array([0.3]),
          np.array([0.4]),
          np.array([1.7]),
          # np.NaN was removed in NumPy 2.0; np.nan is the canonical name.
          np.array([np.nan]),
          np.array([4.4])
      ]),
      'label_key': np.array([
          np.array(['Label']),
          np.array(['Label']),
          np.array(['Label']),
          np.array(['Label']),
          np.array(['Label'])
      ])
  }]
  schema = text_format.Parse(
      """
      feature {
        name: "fa"
        type: BYTES
        shape { dim { size: 1 } }
      }
      feature {
        name: "fb"
        type: FLOAT
        shape { dim { size: 1 } }
      }
      feature {
        name: "fc"
        type: INT
        value_count: { min: 0 max: 2 }
      }
      feature {
        name: "fd"
        type: FLOAT
        shape { dim { size: 1 } }
      }
      feature {
        name: "label_key"
        type: BYTES
        shape { dim { size: 1 } }
      }""", schema_pb2.Schema())
  # Only fa, fb and fd are expected in the output; every MI summary is 0.0 and
  # the num_partitions stats equal the generator's num_partitions (2).
  expected_result = [
      text_format.Parse(
          """
          features {
            name: "fa"
            custom_stats { name: "max_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "max_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 }
            custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 }
            custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 }
          }
          features {
            name: "fb"
            custom_stats { name: "max_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "max_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 }
            custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 }
            custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 }
          }
          features {
            name: "fd"
            custom_stats { name: "max_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "max_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "mean_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "median_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "min_sklearn_mutual_information" num: 0.0 }
            custom_stats { name: "num_partitions_sklearn_adjusted_mutual_information" num: 2.0 }
            custom_stats { name: "num_partitions_sklearn_mutual_information" num: 2.0 }
            custom_stats { name: "std_dev_sklearn_adjusted_mutual_information" num: 0.0 }
            custom_stats { name: "std_dev_sklearn_mutual_information" num: 0.0 }
          }""", statistics_pb2.DatasetFeatureStatistics())
  ]
  generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
      sklearn_mutual_information.SkLearnMutualInformation(
          label_feature='label_key', schema=schema, seed=TEST_SEED),
      num_partitions=2,
      min_partitions_stat_presence=2,
      seed=TEST_SEED,
      max_examples_per_partition=1000,
      name='NonStreaming Mutual Information')
  self.assertTransformOutputEqual(batches, generator, expected_result)