Example #1
 def test_all_null_mask_all_null(self):
   batch = input_batch.InputBatch(
       pa.Table.from_arrays([
           pa.array([None, None], type=pa.null()),
           pa.array([None, None], type=pa.null())
       ], ['f1', 'f2']))
   path1 = types.FeaturePath(['f1'])
   path2 = types.FeaturePath(['f2'])
   expected_mask = np.array([True, True])
   np.testing.assert_array_equal(
       batch.all_null_mask(path1, path2), expected_mask)
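
For context, the mask asserted here is just the row-wise AND of each column's null indicator. A minimal standalone sketch of that semantics, using only pyarrow and numpy (an illustration of the expected behavior, not TFDV's input_batch implementation):

import numpy as np
import pyarrow as pa

def all_null_mask_sketch(table, *column_names):
  # True at row i iff every named column is null at row i.
  mask = np.ones(table.num_rows, dtype=bool)
  for name in column_names:
    column = table.column(name).combine_chunks()
    mask &= np.asarray(column.is_null())
  return mask

table = pa.Table.from_arrays(
    [pa.array([None, None], type=pa.null()),
     pa.array([[1], None])], ['f1', 'f2'])
print(all_null_mask_sketch(table, 'f1', 'f2'))  # [False  True]
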
Example #2
 def test_list_lengths_null_array(self):
   batch = input_batch.InputBatch(
       pa.Table.from_arrays([
           pa.array([None, None, None], type=pa.null()),
       ], ['f1']))
   np.testing.assert_array_equal(
       batch.list_lengths(types.FeaturePath(['f1'])), [0, 0, 0])
 def test_lift_null_y(self):
     examples = [
         pa.Table.from_arrays([
             pa.array([['a'], ['a'], ['b'], ['a']]),
             pa.array([None, None, None, None], type=pa.null()),
         ], ['categorical_x', 'string_y']),
     ]
     schema = text_format.Parse(
         """
     feature {
       name: 'categorical_x'
       type: BYTES
     }
     feature {
       name: 'string_y'
       type: BYTES
     }
     """, schema_pb2.Schema())
     expected_result = []
     generator = lift_stats_generator.LiftStatsGenerator(
         schema=schema, y_path=types.FeaturePath(['string_y']))
     self.assertSlicingAwareTransformOutputEqual(
         examples,
         generator,
         expected_result,
         add_default_slice_key_to_input=True,
         add_default_slice_key_to_output=True)
 def test_example_value_presence_null_array(self):
   t = pa.Table.from_arrays([
       pa.array([None, None], type=pa.null()),
   ], ['x'])
   self.assertIsNone(
       lift_stats_generator._get_example_value_presence(
           t, types.FeaturePath(['x']), boundaries=None))
def _GetEmptyTable(num_rows: int) -> pa.Table:
     # pyarrow doesn't provide an API to create a table with zero columns but
     # a non-zero number of rows, so we work around it by adding a dummy
     # column first and then removing it.
    t = pa.Table.from_arrays([pa.array([None] * num_rows, type=pa.null())],
                             ["dummy"])
    return t.remove_column(0)
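
A quick sanity check of the workaround above (it relies on remove_column preserving the row count, which is the whole point of the dummy column):

t = _GetEmptyTable(3)
assert t.num_columns == 0
assert t.num_rows == 3
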
Example #6
 def test_null_mask_null_array(self):
     batch = input_batch.InputBatch(
         pa.Table.from_arrays([pa.array([None], type=pa.null())],
                              ['feature']))
     path = types.FeaturePath(['feature'])
     expected_mask = np.array([True])
     np.testing.assert_array_equal(batch.null_mask(path), expected_mask)
Example #7
 def _process_column_infos(self, column_infos: List[csv_decoder.ColumnInfo]):
   column_handlers = []
   column_arrow_types = []
   for c in column_infos:
     if c.type == statistics_pb2.FeatureNameStatistics.INT:
       column_handlers.append(lambda v: (int(v),))
       column_arrow_types.append(pa.list_(pa.int64()))
     elif c.type == statistics_pb2.FeatureNameStatistics.FLOAT:
       column_handlers.append(lambda v: (float(v),))
       column_arrow_types.append(pa.list_(pa.float32()))
     elif c.type == statistics_pb2.FeatureNameStatistics.STRING:
       column_handlers.append(lambda v: (v,))
       column_arrow_types.append(pa.list_(pa.binary()))
     else:
       column_handlers.append(lambda _: None)
       column_arrow_types.append(pa.null())
   self._column_handlers = column_handlers
   self._column_arrow_types = column_arrow_types
   self._column_names = [c.name for c in column_infos]
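
To make the handler shapes above concrete, here is a hypothetical standalone snippet (not part of the decoder class): each CSV cell is wrapped in a one-element tuple so it lines up with the list-typed Arrow columns, and unsupported columns map to None alongside pa.null():

handlers = [lambda v: (int(v),),    # INT column
            lambda v: (float(v),),  # FLOAT column
            lambda v: (v,),         # STRING column
            lambda _: None]         # unsupported column -> null
row = [b'42', b'0.5', b'hello', b'ignored']
print([h(cell) for h, cell in zip(handlers, row)])
# [(42,), (0.5,), (b'hello',), None]
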
class BinArrayTest(parameterized.TestCase):
    """Tests for bin_array."""
    @parameterized.named_parameters([
        ('simple', pa.array([0.1, 0.5, 0.75]), [0.25, 0.75],
         [0, 1, 2], [0, 1, 2]),
        ('negative_values', pa.array([-0.8, -0.5, -0.1]), [0.25],
         [0, 1, 2], [0, 0, 0]),
        ('inf_values', pa.array([float('-inf'), 0.5, float('inf')]),
         [0.25, 0.75], [0, 1, 2], [0, 1, 2]),
        ('nan_values', pa.array([np.nan, 0.5]), [0.25, 0.75], [1], [1]),
        ('negative_boundaries', pa.array([-0.8, -0.5]), [-0.75, -0.25],
         [0, 1], [0, 1]),
        ('empty_array', pa.array([]), [0.25], [], []),
        ('none_value', pa.array([None, 0.5]), [0.25], [1], [1]),
        ('null_array', pa.array([None, None], type=pa.null()), [0.25], [], []),
    ])
    def test_bin_array(self, array, boundaries, expected_indices,
                       expected_bins):
        indices, bins = bin_util.bin_array(array, boundaries)
        np.testing.assert_array_equal(expected_indices, indices)
        np.testing.assert_array_equal(expected_bins, bins)
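
Taken together, the cases above pin down the contract: nulls and NaNs are dropped, every surviving element keeps its original index, and the bin is the number of boundaries at or below the value. A hedged sketch of that contract using np.digitize (an assumption about the behavior, not bin_util's implementation):

import numpy as np
import pyarrow as pa

def bin_array_sketch(array, boundaries):
  indices, bins = [], []
  for i, value in enumerate(array.to_pylist()):
    if value is None or (isinstance(value, float) and np.isnan(value)):
      continue  # nulls and NaNs are skipped entirely
    indices.append(i)
    bins.append(int(np.digitize(value, boundaries)))
  return np.array(indices), np.array(bins)

print(bin_array_sketch(pa.array([np.nan, 0.5]), [0.25, 0.75]))
# (array([1]), array([1]))
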
 def test_lift_slice_aware(self):
     examples = [
         ('slice1',
          pa.Table.from_arrays([
              pa.array([['a'], ['a'], ['b'], ['a']]),
              pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
          ], ['categorical_x', 'string_y'])),
         ('slice2',
          pa.Table.from_arrays([
              pa.array([['a'], ['a'], ['a']]),
              pa.array([['cat'], ['dog'], ['dog']]),
          ], ['categorical_x', 'string_y'])),
         ('slice1',
          pa.Table.from_arrays([
              pa.array([['a'], ['a'], ['b'], ['a']]),
              pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
          ], ['categorical_x', 'string_y'])),
         ('slice2',
          pa.Table.from_arrays([
              pa.array([None, None, None, None], type=pa.null()),
              pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
          ], ['categorical_x', 'string_y'])),
     ]
     schema = text_format.Parse(
         """
     feature {
       name: 'categorical_x'
       type: BYTES
     }
     feature {
       name: 'string_y'
       type: BYTES
     }
     """, schema_pb2.Schema())
     expected_result = [
         ('slice1',
          text_format.Parse(
              """
         cross_features {
           path_x {
             step: "categorical_x"
           }
           path_y {
             step: "string_y"
           }
           categorical_cross_stats {
             lift_series {
               y_string: "cat"
               y_count: 4
               lift_values {
                 x_string: "b"
                 lift: 2.0
                 x_count: 2
                 x_and_y_count: 2
               }
               lift_values {
                 x_string: "a"
                 lift: 0.666666984558
                 x_count: 6
                 x_and_y_count: 2
               }
             }
             lift_series {
               y_string: "dog"
               y_count: 4
               lift_values {
                 x_string: "a"
                 lift: 1.33333301544
                 x_count: 6
                 x_and_y_count: 4
               }
               lift_values {
                 x_string: "b"
                 lift: 0.0
                 x_count: 2
                 x_and_y_count: 0
               }
             }
            }
         }""", statistics_pb2.DatasetFeatureStatistics())),
         ('slice2',
          text_format.Parse(
              """
         cross_features {
           path_x {
             step: "categorical_x"
           }
           path_y {
             step: "string_y"
           }
           categorical_cross_stats {
             lift_series {
               y_string: "cat"
               y_count: 3
               lift_values {
                 x_string: "a"
                 lift: 0.777778029441
                 x_count: 3
                 x_and_y_count: 1
               }
             }
             lift_series {
               y_string: "dog"
               y_count: 4
               lift_values {
                 x_string: "a"
                 lift: 1.16666698455
                 x_count: 3
                 x_and_y_count: 2
               }
             }
            }
         }""", statistics_pb2.DatasetFeatureStatistics())),
     ]
     generator = lift_stats_generator.LiftStatsGenerator(
         schema=schema, y_path=types.FeaturePath(['string_y']))
     self.assertSlicingAwareTransformOutputEqual(examples, generator,
                                                 expected_result)
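
The lift values in both expected protos follow from lift(x, y) = P(y | x) / P(y), with probabilities estimated per slice. A small arithmetic check against the slice2 numbers above:

def lift(x_and_y_count, x_count, y_count, num_examples):
  return (x_and_y_count / x_count) / (y_count / num_examples)

# slice2 has 7 examples in total (3 + 4); rows with a null
# 'categorical_x' still count toward the example total.
print(lift(1, 3, 3, 7))  # x='a', y='cat' -> 0.7777...
print(lift(2, 3, 4, 7))  # x='a', y='dog' -> 1.1666...
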
Example #10
    def test_topk_uniques_with_categorical_feature(self):
        examples = [
            pa.Table.from_arrays(
                [pa.array([[12, 23, 34, 12], [45, 23], [12, 12, 34, 45]])],
                ['fa']),
            pa.Table.from_arrays([pa.array([None, None], type=pa.null())],
                                 ['fa'])
        ]

        expected_result = [
            text_format.Parse(
                """
      features {
        path {
          step: 'fa'
        }
        type: INT
        string_stats {
          top_values {
            value: '12'
            frequency: 4
          }
          top_values {
            value: '45'
            frequency: 2
          }
          rank_histogram {
            buckets {
              low_rank: 0
              high_rank: 0
              label: "12"
              sample_count: 4.0
            }
            buckets {
              low_rank: 1
              high_rank: 1
              label: "45"
              sample_count: 2.0
            }
            buckets {
              low_rank: 2
              high_rank: 2
              label: "34"
              sample_count: 2.0
            }
          }
        }
    }""", statistics_pb2.DatasetFeatureStatistics()),
            text_format.Parse(
                """
    features {
        path {
          step: 'fa'
        }
        type: INT
        string_stats {
          unique: 4
        }
      }""", statistics_pb2.DatasetFeatureStatistics()),
        ]

        schema = text_format.Parse(
            """
        feature {
          name: "fa"
          type: INT
          int_domain {
            is_categorical: true
          }
        }
        """, schema_pb2.Schema())
        generator = top_k_uniques_stats_generator.TopKUniquesStatsGenerator(
            schema=schema, num_top_values=2, num_rank_histogram_buckets=3)
        self.assertSlicingAwareTransformOutputEqual(
            examples,
            generator,
            expected_result,
            add_default_slice_key_to_input=True,
            add_default_slice_key_to_output=True)
Example #11
             pa.array([None, [1., 2., 3.], None, None],
                      pa.list_(pa.float32())),
             "f3":
             pa.array([None, None, [b"abc", b"def"], None],
                      pa.list_(pa.binary())),
             "f4":
             pa.array([None, None, None, [8]], pa.list_(pa.int64())),
         }),
    dict(testcase_name="null_array",
         input_examples=[{
             "a": None,
         }, {
             "a": None,
         }],
         expected_output={
             "a": pa.array([None, None], type=pa.null()),
         })
]


class DecodedExamplesToArrowPyTest(parameterized.TestCase):
    @parameterized.named_parameters(*_INVALID_INPUT_TEST_CASES)
    def test_invalid_input(self, test_input, expected_error,
                           expected_error_regexp):
        with self.assertRaisesRegex(expected_error, expected_error_regexp):
            decoded_examples_to_arrow.DecodedExamplesToTablePy(test_input)

    @parameterized.named_parameters(*_CONVERSION_TEST_CASES)
    def test_conversion(self, input_examples, expected_output):
        table = decoded_examples_to_arrow.DecodedExamplesToTablePy(
            input_examples)
 def test_sparse_feature_generator_multiple_sparse_features(self):
     batches = [
         pa.Table.from_arrays([
             pa.array([
                 None, None, ['a', 'b'], ['a', 'b'], ['a', 'b'], None, None
             ]),
             pa.array([[1, 2], [1, 2], None, None, None, None, None]),
             pa.array([[2, 4], [2, 4], [2, 4, 6], [2, 4, 6], [2, 4, 6],
                       None, None]),
             pa.array(
                 [None, None, None, None, None, ['a', 'b'], ['a', 'b']]),
             pa.array([None, None, None, None, None, [2, 4], [2, 4]]),
             pa.array([None, None, None, None, None, None, None],
                      type=pa.null()),
         ], [
             'value_feature', 'index_feature1', 'index_feature2',
             'other_value_feature', 'other_index_feature1',
             'other_index_feature2'
         ]),
         pa.Table.from_arrays([
             pa.array(
                 [None, None, None, None, None, ['a', 'b'], ['a', 'b']]),
             pa.array([None, None, None, None, None, [2, 4], [2, 4]]),
             pa.array([None, None, None, None, None, None, None],
                      type=pa.null())
         ], [
             'other_value_feature', 'other_index_feature1',
             'other_index_feature2'
         ]),
     ]
     schema = text_format.Parse(
         """
     sparse_feature {
       name: 'sparse_feature'
       index_feature {
         name: 'index_feature1'
       }
       index_feature {
         name: 'index_feature2'
       }
       value_feature {
         name: 'value_feature'
       }
     }
     sparse_feature {
       name: 'other_sparse_feature'
       index_feature {
         name: 'other_index_feature1'
       }
       index_feature {
         name: 'other_index_feature2'
       }
       value_feature {
         name: 'other_value_feature'
       }
     }
     """, schema_pb2.Schema())
     expected_result = {
         types.FeaturePath(['sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 2
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 3
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 0
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: 2
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 2
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'index_feature1'
                   sample_count: -2
                 }
                 buckets {
                   label: 'index_feature2'
                   sample_count: 1
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics()),
         types.FeaturePath(['other_sparse_feature']):
         text_format.Parse(
             """
             path {
               step: 'other_sparse_feature'
             }
             custom_stats {
               name: 'missing_value'
               num: 0
             }
             custom_stats {
               name: 'missing_index'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: 4
                 }
               }
             }
             custom_stats {
               name: 'max_length_diff'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: -2
                 }
               }
             }
             custom_stats {
               name: 'min_length_diff'
               rank_histogram {
                 buckets {
                   label: 'other_index_feature1'
                   sample_count: 0
                 }
                 buckets {
                   label: 'other_index_feature2'
                   sample_count: -2
                 }
               }
             }""", statistics_pb2.FeatureNameStatistics())
     }
     generator = (
         sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema))
     self.assertCombinerOutputEqual(batches, generator, expected_result)
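
One way to read the length-diff stats above (a hypothetical sketch, not the generator's code): rows where the sparse feature is entirely absent are skipped, a null component otherwise counts as length 0, and max/min_length_diff report the extremes of len(index) - len(value). Checking only the (value, index) pair for absence happens to suffice for this example:

import pyarrow as pa

def length_diffs(value_col, index_col):
  diffs = []
  for value, index in zip(value_col.to_pylist(), index_col.to_pylist()):
    if value is None and index is None:
      continue  # sparse feature entirely absent in this row
    diffs.append(len(index or []) - len(value or []))
  return diffs

value = pa.array([None, None, ['a', 'b'], ['a', 'b'], ['a', 'b'], None, None])
index1 = pa.array([[1, 2], [1, 2], None, None, None, None, None])
diffs = length_diffs(value, index1)
print(max(diffs), min(diffs))  # 2 -2, as in index_feature1's stats above
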
 def test_basic_stats_generator_handle_null_column(self):
     # Feature 'a' covers null coming before non-null.
     # Feature 'b' covers null coming after non-null.
     b1 = pa.Table.from_arrays([
         pa.array([None, None, None], type=pa.null()),
         pa.array([[1.0, 2.0, 3.0], [4.0], [5.0]]),
     ], ['a', 'b'])
     b2 = pa.Table.from_arrays([
         pa.array([[1, 2], None], type=pa.list_(pa.int64())),
         pa.array([None, None], type=pa.null()),
     ], ['a', 'b'])
     batches = [b1, b2]
     expected_result = {
         types.FeaturePath(['a']):
         text_format.Parse(
             """
         path {
           step: "a"
         }
         num_stats {
           common_stats {
             num_non_missing: 1
             min_num_values: 2
             max_num_values: 2
             avg_num_values: 2.0
             num_values_histogram {
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               type: QUANTILES
             }
             tot_num_values: 2
           }
           mean: 1.5
           std_dev: 0.5
           min: 1.0
           median: 2.0
           max: 2.0
           histograms {
             buckets {
               low_value: 1.0
               high_value: 1.3333333
               sample_count: 0.9955556
             }
             buckets {
               low_value: 1.3333333
               high_value: 1.6666667
               sample_count: 0.0022222
             }
             buckets {
               low_value: 1.6666667
               high_value: 2.0
               sample_count: 1.0022222
             }
           }
           histograms {
             buckets {
               low_value: 1.0
               high_value: 1.0
               sample_count: 0.5
             }
             buckets {
               low_value: 1.0
               high_value: 2.0
               sample_count: 0.5
             }
             buckets {
               low_value: 2.0
               high_value: 2.0
               sample_count: 0.5
             }
             buckets {
               low_value: 2.0
               high_value: 2.0
               sample_count: 0.5
             }
             type: QUANTILES
           }
         }
         """, statistics_pb2.FeatureNameStatistics()),
         types.FeaturePath(['b']):
         text_format.Parse(
             """
         path {
           step: 'b'
         }
         type: FLOAT
         num_stats {
           common_stats {
             num_non_missing: 3
             min_num_values: 1
             max_num_values: 3
             avg_num_values: 1.66666698456
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 3.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               type: QUANTILES
             }
             tot_num_values: 5
           }
           mean: 3.0
           std_dev: 1.4142136
           min: 1.0
           median: 3.0
           max: 5.0
           histograms {
             buckets {
               low_value: 1.0
               high_value: 2.3333333
               sample_count: 1.9888889
             }
             buckets {
               low_value: 2.3333333
               high_value: 3.6666667
               sample_count: 1.0055556
             }
             buckets {
               low_value: 3.6666667
               high_value: 5.0
               sample_count: 2.0055556
             }
           }
           histograms {
             buckets {
               low_value: 1.0
               high_value: 2.0
               sample_count: 1.25
             }
             buckets {
               low_value: 2.0
               high_value: 3.0
               sample_count: 1.25
             }
             buckets {
               low_value: 3.0
               high_value: 4.0
               sample_count: 1.25
             }
             buckets {
               low_value: 4.0
               high_value: 5.0
               sample_count: 1.25
             }
             type: QUANTILES
           }
         }
         """, statistics_pb2.FeatureNameStatistics()),
     }
     generator = basic_stats_generator.BasicStatsGenerator(
         num_values_histogram_buckets=4,
         num_histogram_buckets=3,
         num_quantiles_histogram_buckets=4)
     self.assertCombinerOutputEqual(batches, generator, expected_result)
Example #14
 def test_all_null_mask_no_paths(self):
   batch = input_batch.InputBatch(
        pa.Table.from_arrays([pa.array([None, [1]])], ['f3']))
   with self.assertRaisesRegex(ValueError, r'Paths cannot be empty.*'):
     batch.all_null_mask()
 def test_topk_uniques_combiner_with_categorical_feature(self):
    # fa: value 12 appears 4 times; 23, 34, and 45 appear twice each.
   batches = [
       pa.Table.from_arrays([pa.array([[12, 23, 34, 12], [45, 23]])], ['fa']),
       pa.Table.from_arrays([pa.array([[12, 12, 34, 45]])], ['fa']),
       pa.Table.from_arrays(
           [pa.array([None, None, None, None], type=pa.null())], ['fa']),
   ]
   expected_result = {
       types.FeaturePath(['fa']):
           text_format.Parse(
               """
               path {
                 step: 'fa'
               }
               type: INT
               string_stats {
                 unique: 4
                 top_values {
                   value: '12'
                   frequency: 4
                 }
                 top_values {
                   value: '45'
                   frequency: 2
                 }
                 top_values {
                   value: '34'
                   frequency: 2
                 }
                 top_values {
                   value: '23'
                   frequency: 2
                 }
                 rank_histogram {
                   buckets {
                     low_rank: 0
                     high_rank: 0
                     label: "12"
                     sample_count: 4.0
                   }
                   buckets {
                     low_rank: 1
                     high_rank: 1
                     label: "45"
                     sample_count: 2.0
                   }
                   buckets {
                     low_rank: 2
                     high_rank: 2
                     label: "34"
                     sample_count: 2.0
                   }
                 }
             }""", statistics_pb2.FeatureNameStatistics())
   }
   schema = text_format.Parse(
       """
       feature {
         name: "fa"
         type: INT
         int_domain {
           is_categorical: true
         }
       }
       """, schema_pb2.Schema())
   generator = (
       top_k_uniques_combiner_stats_generator
       .TopKUniquesCombinerStatsGenerator(
           schema=schema, num_top_values=4, num_rank_histogram_buckets=3))
   self.assertCombinerOutputEqual(batches, generator, expected_result)
class WeightedFeatureStatsGeneratorTest(parameterized.TestCase,
                                        test_util.CombinerStatsGeneratorTest):
    @parameterized.named_parameters(
        {
            'testcase_name': 'AllMatching',
            'batches': [
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a', 'b']]),
                     pa.array([[2], [2, 4]])], ['value', 'weight'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': 0.0,
            'expected_max_weight_length_diff': 0.0
        }, {
            'testcase_name': 'AllMatchingMultiBatch',
            'batches': [
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a', 'b']]),
                     pa.array([[2], [2, 4]])], ['value', 'weight']),
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a', 'b']]),
                     pa.array([[2], [2, 4]])], ['value', 'weight'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': 0.0,
            'expected_max_weight_length_diff': 0.0
        }, {
            'testcase_name': 'LengthMismatchPositive',
            'batches': [
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a']]),
                     pa.array([[2], [2, 4]])], ['value', 'weight'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': 0.0,
            'expected_max_weight_length_diff': 1.0
        }, {
            'testcase_name': 'LengthMismatchNegative',
            'batches': [
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a', 'b']]),
                     pa.array([[2], [2]])], ['value', 'weight'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': -1.0,
            'expected_max_weight_length_diff': 0.0
        }, {
            'testcase_name': 'LengthMismatchMultiBatch',
            'batches': [
                pa.Table.from_arrays(
                    [pa.array([['a'], ['a', 'b']]),
                     pa.array([[], []])], ['value', 'weight']),
                pa.Table.from_arrays([pa.array([[1], [1, 1]])], ['other'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': -2.0,
            'expected_max_weight_length_diff': -1.0
        }, {
            'testcase_name': 'SomePairsMissing',
            'batches': [
                pa.Table.from_arrays([
                    pa.array([['a'], None, ['a', 'b']]),
                    pa.array([[1, 1], None, [1, 1, 1]])
                ], ['value', 'weight'])
            ],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': 1.0,
            'expected_max_weight_length_diff': 1.0
        }, {
            'testcase_name': 'EmptyWeights',
            'batches':
            [pa.Table.from_arrays([pa.array([['a'], ['a', 'b']])], ['value'])],
            'expected_missing_weight': 2.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': -2.0,
            'expected_max_weight_length_diff': -1.0
        }, {
            'testcase_name': 'EmptyValues',
            'batches':
            [pa.Table.from_arrays([pa.array([[1], [1, 2]])], ['weight'])],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 2.0,
            'expected_min_weight_length_diff': 1.0,
            'expected_max_weight_length_diff': 2.0
        }, {
            'testcase_name': 'EmptyWeightsAndValues',
            'batches': [pa.Table.from_arrays([], [])],
            'expected_missing_weight': 0.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': 0.0,
            'expected_max_weight_length_diff': 0.0
        }, {
            'testcase_name': 'NullWeightArray',
            'batches': [
                pa.Table.from_arrays([
                    pa.array([['a'], ['a', 'b']]),
                    pa.array([None, None], type=pa.null())
                ], ['value', 'weight'])
            ],
            'expected_missing_weight': 2.0,
            'expected_missing_value': 0.0,
            'expected_min_weight_length_diff': -2.0,
            'expected_max_weight_length_diff': -1.0
        })
    def test_single_weighted_feature(self, batches, expected_missing_weight,
                                     expected_missing_value,
                                     expected_min_weight_length_diff,
                                     expected_max_weight_length_diff):
        schema = text_format.Parse(
            """
        weighted_feature {
          name: 'weighted_feature'
          feature {
            step: 'value'
          }
          weight_feature {
            step: 'weight'
          }
        }
        """, schema_pb2.Schema())
        generator = (weighted_feature_stats_generator.
                     WeightedFeatureStatsGenerator(schema))

        expected_stats = statistics_pb2.FeatureNameStatistics()
        expected_stats.path.step.append('weighted_feature')
        expected_stats.custom_stats.add(name='missing_weight',
                                        num=expected_missing_weight)
        expected_stats.custom_stats.add(name='missing_value',
                                        num=expected_missing_value)
        expected_stats.custom_stats.add(name='min_weight_length_diff',
                                        num=expected_min_weight_length_diff)
        expected_stats.custom_stats.add(name='max_weight_length_diff',
                                        num=expected_max_weight_length_diff)
        expected_result = {
            types.FeaturePath(['weighted_feature']): expected_stats
        }

        self.assertCombinerOutputEqual(batches, generator, expected_result)

    def test_shared_weight(self):
        batches = [
            pa.Table.from_arrays([
                pa.array([['a'], ['a', 'b'], ['a']]),
                pa.array([['x'], ['y'], ['x']]),
                pa.array([[2], [4], None])
            ], ['value1', 'value2', 'weight'])
        ]
        schema = text_format.Parse(
            """
        weighted_feature {
          name: 'weighted_feature1'
          feature {
            step: 'value1'
          }
          weight_feature {
            step: 'weight'
          }
        }
        weighted_feature {
          name: 'weighted_feature2'
          feature {
            step: 'value2'
          }
          weight_feature {
            step: 'weight'
          }
        }""", schema_pb2.Schema())
        generator = (weighted_feature_stats_generator.
                     WeightedFeatureStatsGenerator(schema))

        expected_result = {
            types.FeaturePath(['weighted_feature1']):
            text_format.Parse(
                """
                path {
                  step: 'weighted_feature1'
                }
                custom_stats {
                  name: 'missing_weight'
                  num: 1.0
                }
                custom_stats {
                  name: 'missing_value'
                  num: 0.0
                }
                custom_stats {
                  name: 'min_weight_length_diff'
                  num: -1.0
                }
                custom_stats {
                  name: 'max_weight_length_diff'
                  num: 0.0
                }""", statistics_pb2.FeatureNameStatistics()),
            types.FeaturePath(['weighted_feature2']):
            text_format.Parse(
                """
                path {
                  step: 'weighted_feature2'
                }
                custom_stats {
                  name: 'missing_weight'
                  num: 1.0
                }
                custom_stats {
                  name: 'missing_value'
                  num: 0.0
                }
                custom_stats {
                  name: 'min_weight_length_diff'
                  num: -1.0
                }
                custom_stats {
                  name: 'max_weight_length_diff'
                  num: 0.0
                }""", statistics_pb2.FeatureNameStatistics())
        }

        self.assertCombinerOutputEqual(batches, generator, expected_result)
    def test_mi_regression_with_null_array(self):
        label_array = pa.array([[0.1], [0.2], [0.8], [0.7], [0.2], [0.3],
                                [0.9], [0.4], [0.1], [0.0], [0.4], [0.6],
                                [0.4], [0.8]])
        # Random floats that do not map onto the label
        terrible_feat_array = pa.array([[0.4], [0.1], [0.4], [0.4], [0.8],
                                        [0.7], [0.2], [0.1], [0.0], [0.4],
                                        [0.8], [0.2], [0.5], [0.1]])
        null_array = pa.array([None] * 14, type=pa.null())
        # Note: It is possible to get different results for py2 and py3,
        # depending on the feature name used (e.g., with 'empty_feature' the
        # results differ). This might be due to the scikit-learn function used
        # to compute MI adding a small amount of noise to continuous features
        # before computing MI.
        batch = pa.Table.from_arrays(
            [label_array, label_array, terrible_feat_array, null_array], [
                "label_key", "perfect_feature", "terrible_feature",
                "values_empty_feature"
            ])

        schema = text_format.Parse(
            """
        feature {
          name: "values_empty_feature"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "perfect_feature"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "terrible_feature"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        feature {
          name: "label_key"
          type: FLOAT
          shape {
            dim {
              size: 1
            }
          }
        }
        """, schema_pb2.Schema())

        expected = text_format.Parse(
            """
        features {
          path {
            step: "perfect_feature"
          }
          custom_stats {
            name: "sklearn_adjusted_mutual_information"
            num: 1.0742656
          }
          custom_stats {
            name: "sklearn_mutual_information"
            num: 1.2277528
          }
        }
        features {
          path {
            step: "terrible_feature"
          }
          custom_stats {
            name: "sklearn_adjusted_mutual_information"
            num: 0.0392891
          }
          custom_stats {
            name: "sklearn_mutual_information"
            num: 0.0392891
          }
        }
        features {
          path {
            step: "values_empty_feature"
          }
          custom_stats {
            name: "sklearn_adjusted_mutual_information"
            num: 0.0
          }
          custom_stats {
            name: "sklearn_mutual_information"
            num: 0.0
          }
        }""", statistics_pb2.DatasetFeatureStatistics())
        self._assert_mi_output_equal(batch, expected, schema,
                                     types.FeaturePath(["label_key"]))