def test_sklearn_mi_with_slicing(self):
        """Checks non-streaming MI stats when the input is split into slices.

        Every record batch in ``self.record_batches`` is fed once under the
        key 'slice1' and once under 'slice2'; each slice is expected to yield
        the MI stats produced by ``_get_test_stats_with_mi`` for fa, fb, fd.
        """
        slice_keys = ['slice1', 'slice2']
        sliced_record_batches = [
            (slice_key, record_batch)
            for slice_key in slice_keys
            for record_batch in self.record_batches
        ]

        # A fresh list of feature paths is built for every slice so that no
        # list object is shared between the two expected-stats calls.
        expected_result = [
            (slice_key,
             _get_test_stats_with_mi(
                 [types.FeaturePath([name]) for name in ('fa', 'fb', 'fd')]))
            for slice_key in slice_keys
        ]

        mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
            label_feature=types.FeaturePath(['label_key']),
            schema=self.schema,
            seed=TEST_SEED)
        generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
            mi_stats_fn,
            num_partitions=2,
            min_partitions_stat_presence=2,
            seed=TEST_SEED,
            max_examples_per_partition=1000,
            batch_size=1,
            name='NonStreaming Mutual Information')
        self.assertSlicingAwareTransformOutputEqual(
            sliced_record_batches, generator, expected_result)
 def test_sklearn_mi(self):
     """Checks non-streaming MI stats on unsliced input.

     The input batches are wrapped with the default slice key and the
     expected output is compared under the default slice key as well.
     """
     mi_feature_paths = [
         types.FeaturePath([name]) for name in ('fa', 'fb', 'fd')
     ]
     expected_stats = [_get_test_stats_with_mi(mi_feature_paths)]

     label_path = types.FeaturePath(['label_key'])
     mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
         label_feature=label_path, schema=self.schema, seed=TEST_SEED)
     generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
         mi_stats_fn,
         num_partitions=2,
         min_partitions_stat_presence=2,
         seed=TEST_SEED,
         max_examples_per_partition=1000,
         batch_size=1,
         name='NonStreaming Mutual Information')
     self.assertSlicingAwareTransformOutputEqual(
         self.record_batches,
         generator,
         expected_stats,
         add_default_slice_key_to_input=True,
         add_default_slice_key_to_output=True)
예제 #3
0
    def test_sklearn_mi(self):
        """Runs the non-streaming MI generator end to end over mixed inputs.

        Integration tests involving Beam and AMI are challenging to write
        because Beam PCollections are unordered while the results of adjusted
        MI depend on the order of the data for small datasets. This test case
        tests MI with one label value, which gives a value of 0 regardless of
        the ordering of elements in the PCollection. The purpose of this test
        is to ensure that the Mutual Information pipeline is able to handle a
        variety of input types. Unit tests ensuring correctness of the MI
        value itself are included in sklearn_mutual_information_test.
        """
        # fa is categorical, fb is numeric, fc is multivalent and fd has null
        # values (a NaN in the second batch).
        # NOTE(review): the first fc row has 3 values, which exceeds the
        # schema's value_count max of 2 below, and the last fc row of the first
        # batch is None; fc also does not appear in expected_result. Presumably
        # the MI generator skips fc — confirm this is intentional.
        batches = [{
            'fa':
            np.array([
                np.array(['Red']),
                np.array(['Green']),
                np.array(['Blue']),
                np.array(['Green'])
            ]),
            'fb':
            np.array([
                np.array([1.0]),
                np.array([2.2]),
                np.array([3.3]),
                np.array([1.3])
            ]),
            'fc':
            np.array([
                np.array([1, 3, 1]),
                np.array([2, 6]),
                np.array([4, 6]), None
            ]),
            'fd':
            np.array([
                np.array([0.4]),
                np.array([0.4]),
                np.array([0.3]),
                np.array([0.2])
            ]),
            'label_key':
            np.array([
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label'])
            ])
        }, {
            'fa':
            np.array([
                np.array(['Red']),
                np.array(['Blue']),
                np.array(['Blue']),
                np.array(['Green']),
                np.array(['Green'])
            ]),
            'fb':
            np.array([
                np.array([1.2]),
                np.array([0.5]),
                np.array([1.3]),
                np.array([2.3]),
                np.array([0.3])
            ]),
            'fc':
            np.array([
                np.array([1]),
                np.array([3, 2]),
                np.array([1, 4]),
                np.array([0]),
                np.array([3])
            ]),
            'fd':
            np.array([
                np.array([0.3]),
                np.array([0.4]),
                np.array([1.7]),
                # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
                np.array([np.nan]),
                np.array([4.4])
            ]),
            'label_key':
            np.array([
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label'])
            ])
        }]

        schema = text_format.Parse(
            """
        feature {
          name: "fa"
          type: BYTES
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fb"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fc"
          type: INT
          value_count: {
            min: 0
            max: 2
          }
        }
        feature {
          name: "fd"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "label_key"
          type: BYTES
          shape {
            dim {
              size: 1
            }
          }
        }""", schema_pb2.Schema())

        # Only fa, fb and fd are expected. Every MI/AMI stat is 0.0 because the
        # label takes a single value. The num_partitions_* stats of 2.0
        # presumably reflect num_partitions=2 passed to the generator below.
        expected_result = [
            text_format.Parse(
                """
              features {
                name: "fa"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }
              features {
                name: "fb"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }
              features {
                name: "fd"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }""", statistics_pb2.DatasetFeatureStatistics())
        ]
        generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
            sklearn_mutual_information.SkLearnMutualInformation(
                label_feature='label_key', schema=schema, seed=TEST_SEED),
            num_partitions=2,
            min_partitions_stat_presence=2,
            seed=TEST_SEED,
            max_examples_per_partition=1000,
            name='NonStreaming Mutual Information')
        self.assertTransformOutputEqual(batches, generator, expected_result)