예제 #1
0
    def setUp(self):
        """Builds a FeatureVisualizer backed by a mocked BigQuery client.

        `utils.read_file` is patched (autospec) so the real function is not
        called. The mocked client's `query(...).to_dataframe()` returns, on
        successive calls, NUMERICAL_FEATURES_STATS, NUMERICAL_FEATURES_SAMPLE
        and CATEGORICAL_FEATURES_STATS.
        """
        # Stop every patch started with `.start()` when the test finishes.
        self.addCleanup(absltest.mock.patch.stopall)
        super(FeatureVisualizerTest, self).setUp()

        self.mock_bq_client = absltest.mock.create_autospec(
            bigquery.client.Client)
        self.features_table_path = 'project_id.dataset.features_table'
        self.numerical_features = ('num_feature1', 'num_feature2')
        self.categorical_features = ('cat_feature1', 'cat_feature2')
        # No trailing commas here: `x = value,` binds a 1-tuple `(value,)`
        # instead of `value`, which would be passed on to FeatureVisualizer.
        self.label_column = 'predictionLabel'
        self.positive_class_label = True
        self.negative_class_label = False
        self.num_pos_instances = 10000
        self.num_neg_instances = 10000

        self.feature_viz_obj = feature_visualizer.FeatureVisualizer(
            bq_client=self.mock_bq_client,
            features_table_path=self.features_table_path,
            numerical_features=self.numerical_features,
            categorical_features=self.categorical_features,
            label_column=self.label_column,
            positive_class_label=self.positive_class_label,
            negative_class_label=self.negative_class_label,
            num_pos_instances=self.num_pos_instances,
            num_neg_instances=self.num_neg_instances)

        self.mock_configure_sql = absltest.mock.patch.object(
            utils, 'read_file', autospec=True).start()

        # Each `to_dataframe()` call consumes the next fixture in this list.
        self.mock_bq_client.query.return_value.to_dataframe.side_effect = [
            NUMERICAL_FEATURES_STATS, NUMERICAL_FEATURES_SAMPLE,
            CATEGORICAL_FEATURES_STATS
        ]
예제 #2
0
  def test_plot_features_returns_correct_plots_for_numeric_label(self):
    """Checks plot_features() output for a numerical label.

    Expects two plots for num_feature1 (a label-vs-feature scatter plot and a
    snapshot-level distribution) and two plots for cat_feature1 (a label
    distribution by category and a snapshot-level distribution).
    """

    def sorted_xtick_texts(axes):
      """Returns the text of every x-axis tick label of `axes`, sorted."""
      return sorted(tick.get_text() for tick in axes.get_xticklabels())

    self.feature_viz_obj = feature_visualizer.FeatureVisualizer(
        bq_client=self.mock_bq_client,
        features_table_path=self.features_table_path,
        numerical_features=self.numerical_features,
        categorical_features=self.categorical_features,
        label_column=self.label_column,
        label_type='numerical',
        num_instances=10000)

    self.mock_configure_sql = absltest.mock.patch.object(
        utils, 'read_file', autospec=True).start()

    # Each `to_dataframe()` call consumes the next fixture in this list.
    self.mock_bq_client.query.return_value.to_dataframe.side_effect = [
        NUMERICAL_LABEL_NUMERICAL_FEATURES_STATS,
        NUMERICAL_LABEL_NUMERICAL_FEATURES_SAMPLE,
        NUMERICAL_LABEL_CATEGORICAL_FEATURES_STATS,
        NUMERICAL_LABEL_STATS
    ]

    num_feature_1_plots, cat_feature_1_plots = (
        self.feature_viz_obj.plot_features())

    with self.subTest(name='test the number of plots returned'):
      self.assertLen(num_feature_1_plots, 2)
      self.assertLen(cat_feature_1_plots, 2)

    # Test elements in numerical feature plots.
    # NOTE(review): unlike the categorical case below, these dates are not
    # filtered to num_feature1 rows; this assumes all numerical features share
    # the same snapshot dates in the fixture — confirm.
    num_snapshot_dates = sorted(set(
        NUMERICAL_LABEL_NUMERICAL_FEATURES_STATS['snapshot_date']))

    with self.subTest(name='test label vs num_feature1 scatter plot'):
      self.assertEqual(
          'Scatter plot of the Label vs [num_feature1] |  correlation = 0.71',
          num_feature_1_plots[0].get_title())

    with self.subTest(
        name='test snapshot distribution of num_feature1'):
      self.assertListEqual(
          num_snapshot_dates, sorted_xtick_texts(num_feature_1_plots[1]))
      self.assertEqual(
          'Snapshot-level distribution of [num_feature1]',
          num_feature_1_plots[1].get_title())

    # Test elements in categorical feature plots.

    # Sorted so the comparison does not depend on the fixture's row order;
    # the tick labels it is compared against are sorted as well.
    cat_feature1_values = sorted(
        NUMERICAL_LABEL_STATS[
            NUMERICAL_LABEL_STATS['feature'] == 'cat_feature1']['value'])

    cat_snapshot_dates = sorted(set(
        NUMERICAL_LABEL_CATEGORICAL_FEATURES_STATS[
            NUMERICAL_LABEL_CATEGORICAL_FEATURES_STATS[
                'feature'] == 'cat_feature1']['snapshot_date']))

    with self.subTest(name='test the label distribution plot'):
      self.assertListEqual(
          cat_feature1_values, sorted_xtick_texts(cat_feature_1_plots[0]))
      self.assertEqual(
          'Label distribution by [cat_feature1] categories',
          cat_feature_1_plots[0].get_title())

    with self.subTest(
        name='test snapshot distribution of cat_feature1'):
      self.assertListEqual(
          cat_snapshot_dates, sorted_xtick_texts(cat_feature_1_plots[1]))
      self.assertEqual(
          'Snapshot-level distribution of [cat_feature1]',
          cat_feature_1_plots[1].get_title())
예제 #3
0
  def test_plot_features_returns_correct_plots_for_binary_label(self):
    """Checks plot_features() output for a binary label.

    Expects three plots for num_feature1 and three for cat_feature1: an
    overall distribution plot, then a snapshot-level distribution for
    label = True and one for label = False.
    """

    def sorted_xtick_texts(axes):
      """Returns the text of every x-axis tick label of `axes`, sorted."""
      return sorted(tick.get_text() for tick in axes.get_xticklabels())

    def snapshot_dates_for_label(feature_data, label):
      """Returns the sorted unique snapshot dates of rows matching `label`."""
      return sorted(set(
          feature_data[feature_data['label'] == label]['snapshot_date']))

    self.feature_viz_obj = feature_visualizer.FeatureVisualizer(
        bq_client=self.mock_bq_client,
        features_table_path=self.features_table_path,
        numerical_features=self.numerical_features,
        categorical_features=self.categorical_features,
        label_column=self.label_column,
        label_type='binary',
        positive_class_label='True',
        negative_class_label='False',
        num_pos_instances=10000,
        num_neg_instances=10000)

    self.mock_configure_sql = absltest.mock.patch.object(
        utils, 'read_file', autospec=True).start()

    # Each `to_dataframe()` call consumes the next fixture in this list.
    self.mock_bq_client.query.return_value.to_dataframe.side_effect = [
        BINARY_LABEL_NUMERICAL_FEATURES_STATS,
        BINARY_LABEL_NUMERICAL_FEATURES_SAMPLE,
        BINARY_LABEL_CATEGORICAL_FEATURES_STATS
    ]

    # Compared in order (not sorted) against the box plot's y tick labels.
    label_values = ['False', 'True']

    num_feature1_data = BINARY_LABEL_NUMERICAL_FEATURES_STATS[
        BINARY_LABEL_NUMERICAL_FEATURES_STATS['feature'] == 'num_feature1']
    snapshot_dates_num_true = snapshot_dates_for_label(
        num_feature1_data, 'True')
    snapshot_dates_num_false = snapshot_dates_for_label(
        num_feature1_data, 'False')

    cat_feature1_data = BINARY_LABEL_CATEGORICAL_FEATURES_STATS[
        BINARY_LABEL_CATEGORICAL_FEATURES_STATS['feature'] == 'cat_feature1']
    cat_feature1_values = sorted(set(cat_feature1_data['value']))
    snapshot_dates_cat_true = snapshot_dates_for_label(
        cat_feature1_data, 'True')
    snapshot_dates_cat_false = snapshot_dates_for_label(
        cat_feature1_data, 'False')

    num_feature_1_plots, cat_feature_1_plots = (
        self.feature_viz_obj.plot_features())

    with self.subTest(name='test the number of plots returned'):
      self.assertLen(num_feature_1_plots, 3)
      self.assertLen(cat_feature_1_plots, 3)

    # Test elements in numerical feature plots.
    with self.subTest(name='test num_feature1 distribution box plot by label'):
      self.assertListEqual(label_values, [
          tick.get_text() for tick in num_feature_1_plots[0].get_yticklabels()
      ])
      self.assertEqual('Distribution of [num_feature1]',
                       num_feature_1_plots[0].get_title())

    with self.subTest(
        name='test snapshot distribution of num_feature1 when label=True'):
      self.assertListEqual(
          snapshot_dates_num_true, sorted_xtick_texts(num_feature_1_plots[1]))
      self.assertEqual(
          'Snapshot-level distribution of [num_feature1] for label = True',
          num_feature_1_plots[1].get_title())

    with self.subTest(
        name='test snapshot distribution of num_feature1 when label=False'):
      self.assertListEqual(
          snapshot_dates_num_false, sorted_xtick_texts(num_feature_1_plots[2]))
      self.assertEqual(
          'Snapshot-level distribution of [num_feature1] for label = False',
          num_feature_1_plots[2].get_title())

    # Test elements in categorical feature plots.
    with self.subTest(name='test the cat_feature1 distribution plot'):
      self.assertListEqual(
          cat_feature1_values, sorted_xtick_texts(cat_feature_1_plots[0]))
      self.assertEqual('Distribution of [cat_feature1]',
                       cat_feature_1_plots[0].get_title())

    with self.subTest(
        name='test snapshot distribution of cat_feature1 when label=True'):
      self.assertListEqual(
          snapshot_dates_cat_true, sorted_xtick_texts(cat_feature_1_plots[1]))
      self.assertEqual(
          'Snapshot-level distribution of [cat_feature1] for label = True',
          cat_feature_1_plots[1].get_title())

    with self.subTest(
        name='test snapshot distribution of cat_feature1 when label=False'):
      self.assertListEqual(
          snapshot_dates_cat_false, sorted_xtick_texts(cat_feature_1_plots[2]))
      self.assertEqual(
          'Snapshot-level distribution of [cat_feature1] for label = False',
          cat_feature_1_plots[2].get_title())