Ejemplo n.º 1
0
    def test_observe(self):
        # SPARK-36263: covers DataFrame.observe(Observation, *Column).
        # Kept as comments (not a docstring) so unittest's verbose output
        # still shows the test id instead of a description line.
        from pyspark.sql import Observation

        columns = ['id', 'val', 'label']
        data = [
            (1, 1.0, 'one'),
            (2, 2.0, 'two'),
            (3, 3.0, 'three'),
        ]
        df = SparkSession(self.sc).createDataFrame(data, columns)

        unnamed_observation = Observation()
        named_observation = Observation("metric")

        # Attach the observations one after the other: first the named one
        # with three metrics, then the unnamed one with a single metric.
        with_named = df.orderBy('id').observe(
            named_observation,
            count(lit(1)).alias('cnt'),
            sum(col("id")).alias('sum'),
            mean(col("val")).alias('mean'),
        )
        observed = with_named.observe(
            unnamed_observation,
            count(lit(1)).alias('rows'),
        )

        # observe has to be transparent: collecting the observed frame
        # must give back exactly the input rows (ordered by id)
        collected = observed.collect()
        expected_rows = [dict(zip(columns, values)) for values in data]
        self.assertEqual(expected_rows, [row.asDict() for row in collected])

        # once the action has run, both observations expose their metrics
        named_metrics = named_observation.get
        self.assertEqual(named_metrics, Row(cnt=3, sum=6, mean=2.0))
        unnamed_metrics = unnamed_observation.get
        self.assertEqual(unnamed_metrics, Row(rows=3))

        # a name, when given, has to be a non-empty string
        invalid_names = (
            (123, TypeError, 'name should be a string'),
            ('', ValueError, 'name should not be empty'),
        )
        for bad_name, error_type, message in invalid_names:
            with self.assertRaisesRegex(error_type, message):
                Observation(bad_name)

        # DataFrame.observe rejects a call without any expression
        no_exprs_message = 'exprs should not be empty'
        with self.assertRaisesRegex(AssertionError, no_exprs_message):
            df.observe(Observation())

        # every expression handed to DataFrame.observe must be a Column;
        # None and plain strings are refused wherever they appear
        not_column_message = 'all exprs should be Column'
        invalid_exprs = [
            (None, ),
            ('id', ),
            (lit(1), None),
            (lit(1), 'id'),
        ]
        for bad_exprs in invalid_exprs:
            with self.subTest(args=bad_exprs), \
                    self.assertRaisesRegex(AssertionError,
                                           not_column_message):
                df.observe(Observation(), *bad_exprs)