def test_sklearn_mi(self):
    """Checks the sklearn MI generator output with the default slice key.

    Feeds `self.record_batches` through a `NonStreamingCustomStatsGenerator`
    wrapping `SkLearnMutualInformation` and compares the output with the
    stats `_get_test_stats_with_mi` builds for the feature paths `fa`, `fb`
    and `fd`. The default slice key is added to both the input and the
    expected output by the assertion helper.
    """
    feature_paths = [
        types.FeaturePath([feature_name]) for feature_name in ('fa', 'fb', 'fd')
    ]
    expected_result = [_get_test_stats_with_mi(feature_paths)]

    # Build the wrapped MI stats function first so the generator call below
    # only carries the partitioning configuration.
    mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
        label_feature=types.FeaturePath(['label_key']),
        schema=self.schema,
        seed=TEST_SEED)
    generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
        mi_stats_fn,
        num_partitions=2,
        min_partitions_stat_presence=2,
        seed=TEST_SEED,
        max_examples_per_partition=1000,
        batch_size=1,
        name='NonStreaming Mutual Information')

    self.assertSlicingAwareTransformOutputEqual(
        self.record_batches,
        generator,
        expected_result,
        add_default_slice_key_to_input=True,
        add_default_slice_key_to_output=True)
    def test_sklearn_mi_with_slicing(self):
        """Checks the sklearn MI generator output for two explicit slices.

        Every batch in `self.record_batches` is tagged once with 'slice1' and
        once with 'slice2'; each slice is expected to yield the stats that
        `_get_test_stats_with_mi` builds for `fa`, `fb` and `fd`.
        """
        slice_keys = ['slice1', 'slice2']

        # Slice-major ordering: all batches for 'slice1', then all for 'slice2'.
        sliced_record_batches = [
            (slice_key, record_batch)
            for slice_key in slice_keys
            for record_batch in self.record_batches
        ]

        # A fresh expected-stats object is built for every slice.
        expected_result = [
            (slice_key,
             _get_test_stats_with_mi([
                 types.FeaturePath([feature_name])
                 for feature_name in ('fa', 'fb', 'fd')
             ]))
            for slice_key in slice_keys
        ]

        mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
            label_feature=types.FeaturePath(['label_key']),
            schema=self.schema,
            seed=TEST_SEED)
        generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
            mi_stats_fn,
            num_partitions=2,
            min_partitions_stat_presence=2,
            seed=TEST_SEED,
            max_examples_per_partition=1000,
            batch_size=1,
            name='NonStreaming Mutual Information')

        self.assertSlicingAwareTransformOutputEqual(
            sliced_record_batches, generator, expected_result)
Пример #3
0
  def test_mi_with_missing_label_key(self):
    """Expects a ValueError when the label feature is absent from the schema.

    The batch and schema only declare `label` and `fa`, while the generator is
    asked to use `label_key` as the label feature.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([[1]]), pa.array([[1]])], ["label", "fa"])

    schema = text_format.Parse(
        """
          feature {
            name: "fa"
            type: FLOAT
              shape {
              dim {
                size: 1
              }
            }
          }
          feature {
            name: "label"
            type: FLOAT
            shape {
              dim {
                size: 1
              }
            }
          }
          """, schema_pb2.Schema())

    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError,
                                "Feature label_key not found in the schema."):
      sklearn_mutual_information.SkLearnMutualInformation(
          types.FeaturePath(["label_key"]), schema, TEST_SEED).compute(batch)
Пример #4
0
  def test_mi_with_multivalent_label(self):
    """Expects a ValueError when the label column holds more than one value.

    The single example carries two values for `label_key`, and the schema
    declares it with a `value_count` of 1..2 rather than a fixed shape.
    """
    batch = pa.RecordBatch.from_arrays(
        [pa.array([[1, 2]]), pa.array([[1]])], ["label_key", "fa"])
    schema = text_format.Parse(
        """
          feature {
            name: "fa"
            type: FLOAT
            shape {
              dim {
                size: 1
              }
            }
          }
          feature {
            name: "label_key"
            type: FLOAT
            value_count: {
              min: 1
              max: 2
            }
          }
          """, schema_pb2.Schema())

    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError,
                                "Label column contains unsupported data."):
      sklearn_mutual_information.SkLearnMutualInformation(
          types.FeaturePath(["label_key"]), schema, TEST_SEED).compute(batch)
 def test_mi_with_invalid_features(self):
     """Expects a ValueError when the only non-label feature is multivalent.

     The table holds a univalent `label_key` and a single feature whose schema
     entry declares exactly two values per example.
     """
     batch = pa.Table.from_arrays(
         [pa.array([[1]]), pa.array([[1, 2]])],
         ["label_key", "multivalent_feature"])
     schema = text_format.Parse(
         """
     feature {
       name: "label_key"
       type: INT
       shape {
         dim {
           size: 1
         }
       }
     }
     feature {
       name: "multivalent_feature"
       type: INT
       value_count: {
         min: 2
         max: 2
       }
     }
     """, schema_pb2.Schema())
     # NOTE(review): the expected message looks like sklearn's empty-input
     # error, presumably raised once the multivalent feature is dropped --
     # confirm against SkLearnMutualInformation.
     # assertRaisesRegexp is a deprecated alias that was removed in Python
     # 3.12; assertRaisesRegex is the supported spelling.
     with self.assertRaisesRegex(ValueError, "Found array with 0 sample"):
         sklearn_mutual_information.SkLearnMutualInformation(
             types.FeaturePath(["label_key"]), schema,
             TEST_SEED).compute(batch)
Пример #6
0
  def test_mi_with_multivalent_label(self):
    """Expects a ValueError when the label column holds more than one value.

    The first example carries two values for `label_key`, and the schema
    declares it with a `value_count` of 1..2 rather than a fixed shape.
    """
    batch = {
        "fa": [np.array([1.0]), np.array([2.0])],
        "label_key": [np.array([1.0, 2.0]), np.array([2.0])]
    }

    schema = text_format.Parse(
        """
          feature {
            name: "fa"
            type: FLOAT
            shape {
              dim {
                size: 1
              }
            }
          }
          feature {
            name: "label_key"
            type: FLOAT
            value_count: {
              min: 1
              max: 2
            }
          }
          """, schema_pb2.Schema())

    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError,
                                "Label column contains unsupported data."):
      sklearn_mutual_information.SkLearnMutualInformation(
          "label_key", schema, TEST_SEED).compute(batch)
Пример #7
0
  def test_mi_with_missing_label_key(self):
    """Expects a ValueError when the label feature is absent from the schema.

    The batch and schema only declare `label` and `fa`, while the generator is
    asked to use `label_key` as the label feature.
    """
    batch = {
        "fa": [np.array([1.0]), np.array([2.0])],
        "label": [np.array([1.0]), np.array([2.0])]
    }

    schema = text_format.Parse(
        """
          feature {
            name: "fa"
            type: FLOAT
              shape {
              dim {
                size: 1
              }
            }
          }
          feature {
            name: "label"
            type: FLOAT
            shape {
              dim {
                size: 1
              }
            }
          }
          """, schema_pb2.Schema())

    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError,
                                "Feature label_key not found in the schema."):
      sklearn_mutual_information.SkLearnMutualInformation(
          "label_key", schema, TEST_SEED).compute(batch)
  def test_mi_with_invalid_features(self):
    """Expects a ValueError when the only non-label feature is multivalent.

    The batch holds a univalent `label_key` and a single feature whose schema
    entry declares exactly two values per example.
    """
    batch = {
        "label_key": np.array([np.array([1])]),
        "multivalent_feature": np.array([np.array([1, 2])])
    }

    schema = text_format.Parse(
        """
        feature {
          name: "label_key"
          type: INT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "multivalent_feature"
          type: INT
          value_count: {
            min: 2
            max: 2
          }
        }
        """, schema_pb2.Schema())
    # NOTE(review): the expected message looks like sklearn's empty-input
    # error, presumably raised once the multivalent feature is dropped --
    # confirm against SkLearnMutualInformation.
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12;
    # assertRaisesRegex is the supported spelling.
    with self.assertRaisesRegex(ValueError, "Found array with 0 sample"):
      sklearn_mutual_information.SkLearnMutualInformation(
          "label_key", schema, TEST_SEED).compute(batch)
 def _assert_mi_output_equal(self, batch, expected, schema, label_feature):
     """Computes MI stats for `batch` and asserts they equal `expected`.

     Args:
       batch: Input handed to `SkLearnMutualInformation.compute`.
       expected: Proto the computed result is compared against.
       schema: Schema passed to the `SkLearnMutualInformation` constructor.
       label_feature: Label feature passed to the constructor.
     """
     mi_stats_fn = sklearn_mutual_information.SkLearnMutualInformation(
         label_feature, schema, TEST_SEED)
     computed = mi_stats_fn.compute(batch)
     # Numbers are normalized before the protos are compared.
     compare.assertProtoEqual(
         self, computed, expected, normalize_numbers=True)
Пример #10
0
    def test_sklearn_mi(self):
        """Runs the non-streaming MI generator end to end on mixed inputs.

        Two batches with a constant label are fed through
        `NonStreamingCustomStatsGenerator`; every MI statistic for `fa`, `fb`
        and `fd` is expected to be 0 and to be present in both partitions.
        """
        # Integration tests involving Beam and AMI are challenging to write
        # because Beam PCollections are unordered while the results of adjusted MI
        # depend on the order of the data for small datasets. This test case tests
        # MI with one label which will give a value of 0 regardless of
        # the ordering of elements in the PCollection. The purpose of this test is
        # to ensure that the Mutual Information pipeline is able to handle a
        # variety of input types. Unit tests ensuring correctness of the MI value
        # itself are included in sklearn_mutual_information_test.

        # fa is categorical, fb is numeric, fc is multivalent and fd has null values
        batches = [{
            'fa':
            np.array([
                np.array(['Red']),
                np.array(['Green']),
                np.array(['Blue']),
                np.array(['Green'])
            ]),
            'fb':
            np.array([
                np.array([1.0]),
                np.array([2.2]),
                np.array([3.3]),
                np.array([1.3])
            ]),
            # Rows have different lengths (and one is None), so the outer array
            # must be an explicit object array: recent NumPy releases raise
            # ValueError instead of silently inferring dtype=object for ragged
            # input.
            'fc':
            np.array([
                np.array([1, 3, 1]),
                np.array([2, 6]),
                np.array([4, 6]), None
            ], dtype=object),
            'fd':
            np.array([
                np.array([0.4]),
                np.array([0.4]),
                np.array([0.3]),
                np.array([0.2])
            ]),
            'label_key':
            np.array([
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label'])
            ])
        }, {
            'fa':
            np.array([
                np.array(['Red']),
                np.array(['Blue']),
                np.array(['Blue']),
                np.array(['Green']),
                np.array(['Green'])
            ]),
            'fb':
            np.array([
                np.array([1.2]),
                np.array([0.5]),
                np.array([1.3]),
                np.array([2.3]),
                np.array([0.3])
            ]),
            # Ragged rows again: request an object array explicitly.
            'fc':
            np.array([
                np.array([1]),
                np.array([3, 2]),
                np.array([1, 4]),
                np.array([0]),
                np.array([3])
            ], dtype=object),
            'fd':
            np.array([
                np.array([0.3]),
                np.array([0.4]),
                np.array([1.7]),
                # The `np.NaN` alias was removed in NumPy 2.0; `np.nan` is
                # available on every NumPy version.
                np.array([np.nan]),
                np.array([4.4])
            ]),
            'label_key':
            np.array([
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label']),
                np.array(['Label'])
            ])
        }]

        schema = text_format.Parse(
            """
        feature {
          name: "fa"
          type: BYTES
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fb"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "fc"
          type: INT
          value_count: {
            min: 0
            max: 2
          }
        }
        feature {
          name: "fd"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "label_key"
          type: BYTES
          shape {
            dim {
              size: 1
            }
          }
        }""", schema_pb2.Schema())

        # Only fa, fb and fd appear in the expected output; the multivalent fc
        # and the label itself have no entries.
        expected_result = [
            text_format.Parse(
                """
              features {
                name: "fa"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }
              features {
                name: "fb"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }
              features {
                name: "fd"
                custom_stats {
                  name: "max_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "max_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "mean_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "median_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "min_sklearn_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_adjusted_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "num_partitions_sklearn_mutual_information"
                  num: 2.0
                }
                custom_stats {
                  name: "std_dev_sklearn_adjusted_mutual_information"
                  num: 0.0
                }
                custom_stats {
                  name: "std_dev_sklearn_mutual_information"
                  num: 0.0
                }
              }""", statistics_pb2.DatasetFeatureStatistics())
        ]
        generator = partitioned_stats_generator.NonStreamingCustomStatsGenerator(
            sklearn_mutual_information.SkLearnMutualInformation(
                label_feature='label_key', schema=schema, seed=TEST_SEED),
            num_partitions=2,
            min_partitions_stat_presence=2,
            seed=TEST_SEED,
            max_examples_per_partition=1000,
            name='NonStreaming Mutual Information')
        self.assertTransformOutputEqual(batches, generator, expected_result)